Unverified Commit 401e1278, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] clang-format auto fix. (#4811)



* [Misc] clang-format auto fix.

* fix

* manual
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 6c53f351
@@ -11,7 +11,7 @@
 #include <ws2tcpip.h>
 #pragma comment(lib, "Ws2_32.lib")
 #else  // !_WIN32
 #include <sys/socket.h>
 #endif  // _WIN32
 #include <string>

@@ -20,7 +20,7 @@ namespace dgl {
 namespace network {
 /*!
  * \brief TCPSocket is a simple wrapper around a socket.
  * It supports only TCP connections.
  */
 class TCPSocket {

@@ -32,7 +32,7 @@ class TCPSocket {
   /*!
    * \brief TCPSocket deconstructor
    */
   ~TCPSocket();
   /*!

@@ -41,7 +41,7 @@ class TCPSocket {
    * \param port end port
    * \return true for success and false for failure
    */
-  bool Connect(const char * ip, int port);
+  bool Connect(const char* ip, int port);
   /*!
    * \brief Bind on the given IP and PORT

@@ -49,7 +49,7 @@ class TCPSocket {
    * \param port end port
    * \return true for success and false for failure
    */
-  bool Bind(const char * ip, int port);
+  bool Bind(const char* ip, int port);
   /*!
    * \brief listen for remote connection

@@ -65,9 +65,7 @@ class TCPSocket {
    * \param port_client new PORT will be stored to port_client
    * \return true for success and false for failure
    */
-  bool Accept(TCPSocket * socket,
-              std::string * ip_client,
-              int * port_client);
+  bool Accept(TCPSocket* socket, std::string* ip_client, int* port_client);
   /*!
    * \brief SetNonBlocking() is needed refering to this example of epoll:

@@ -103,27 +101,27 @@ class TCPSocket {
    * \param data data for sending
    * \param len_data length of data
    * \return return number of bytes sent if OK, -1 on error
    */
-  int64_t Send(const char * data, int64_t len_data);
+  int64_t Send(const char* data, int64_t len_data);
   /*!
    * \brief Receive data.
    * \param buffer buffer for receving
    * \param size_buffer size of buffer
    * \return return number of bytes received if OK, -1 on error
    */
-  int64_t Receive(char * buffer, int64_t size_buffer);
+  int64_t Receive(char* buffer, int64_t size_buffer);
   /*!
    * \brief Get socket's file descriptor
    * \return socket's file descriptor
    */
   int Socket() const;

  private:
   /*!
    * \brief socket's file descriptor
    */
   int socket_;
 };
...
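As a reading aid, here is a minimal client-side sketch built only from the signatures shown in this header; TCPSocket's constructor is not visible in the hunk, so construction and cleanup details are assumptions.

  // Hypothetical client-side use of the TCPSocket wrapper declared above.
  // Connect()/Send() signatures come from this header; everything else is
  // an assumption for illustration.
  #include <cstdint>
  #include <string>

  bool SendGreeting(dgl::network::TCPSocket* sock) {
    if (!sock->Connect("127.0.0.1", 50051)) {  // false indicates failure
      return false;
    }
    const std::string msg = "hello";
    // Send() returns the number of bytes sent, or -1 on error.
    return sock->Send(msg.data(), msg.size()) ==
           static_cast<int64_t>(msg.size());
  }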
This diff is collapsed.
@@ -6,24 +6,25 @@
 #ifndef DGL_RPC_RPC_H_
 #define DGL_RPC_RPC_H_
-#include <dgl/runtime/object.h>
 #include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
 #include <dgl/zerocopy_serializer.h>
 #include <dmlc/thread_local.h>
 #include <cstdint>
-#include <memory>
 #include <deque>
-#include <vector>
+#include <memory>
-#include <string>
 #include <mutex>
+#include <string>
 #include <unordered_map>
+#include <vector>
-#include "./network/common.h"
 #include "./rpc_msg.h"
-#include "./server_state.h"
 #include "net_type.h"
 #include "network/socket_communicator.h"
 #include "tensorpipe/tp_communicator.h"
+#include "./network/common.h"
+#include "./server_state.h"
 namespace dgl {
 namespace rpc {

@@ -138,7 +139,7 @@ struct RPCContext {
   }
   int32_t RegisterClient(int32_t client_id, int32_t group_id) {
-    auto &&m = clients_[group_id];
+    auto&& m = clients_[group_id];
     if (m.find(client_id) != m.end()) {
       return -1;
     }

@@ -150,7 +151,7 @@ struct RPCContext {
     if (clients_.find(group_id) == clients_.end()) {
       return -1;
     }
-    const auto &m = clients_.at(group_id);
+    const auto& m = clients_.at(group_id);
     if (m.find(client_id) == m.end()) {
       return -1;
     }
...

@@ -6,8 +6,8 @@
 #ifndef DGL_RPC_RPC_MSG_H_
 #define DGL_RPC_RPC_MSG_H_
-#include <dgl/runtime/object.h>
 #include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
 #include <dgl/zerocopy_serializer.h>
 #include <string>
...

@@ -7,11 +7,12 @@
 #ifndef DGL_RPC_SERVER_STATE_H_
 #define DGL_RPC_SERVER_STATE_H_
-#include <dgl/runtime/object.h>
-#include <dgl/runtime/ndarray.h>
 #include <dgl/base_heterograph.h>
-#include <unordered_map>
+#include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
 #include <string>
+#include <unordered_map>
 namespace dgl {
 namespace rpc {
...

@@ -9,10 +9,11 @@
 #define DGL_RPC_TENSORPIPE_QUEUE_H_
 #include <dmlc/logging.h>
+#include <chrono>
 #include <condition_variable>
 #include <deque>
 #include <mutex>
-#include <chrono>
 #include <utility>
 namespace dgl {

@@ -39,8 +40,9 @@ class Queue {
       DLOG(WARNING) << "Will wait infinitely until message is popped...";
       cv_.wait(lock, [this] { return items_.size() > 0; });
     } else {
-      if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
-                        [this] { return items_.size() > 0; })) {
+      if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout), [this] {
+            return items_.size() > 0;
+          })) {
         DLOG(WARNING) << "Times out for popping message after " << timeout
                       << " milliseconds.";
         return false;
...
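The reflowed wait_for call above is the standard timed-pop idiom: std::condition_variable::wait_for with a predicate returns false only when the timeout expires while the predicate is still false. A self-contained sketch of the same pattern (class and member names are illustrative, not DGL's):

  #include <chrono>
  #include <condition_variable>
  #include <deque>
  #include <mutex>

  template <typename T>
  class TimedQueue {
   public:
    // Returns false if no item arrived within `timeout` milliseconds.
    bool Pop(T* out, int timeout) {
      std::unique_lock<std::mutex> lock(mutex_);
      if (!cv_.wait_for(lock, std::chrono::milliseconds(timeout),
                        [this] { return !items_.empty(); })) {
        return false;  // timed out with the queue still empty
      }
      *out = std::move(items_.front());
      items_.pop_front();
      return true;
    }

    void Push(T item) {
      {
        std::lock_guard<std::mutex> lock(mutex_);
        items_.push_back(std::move(item));
      }
      cv_.notify_one();  // wake one waiting Pop()
    }

   private:
    std::mutex mutex_;
    std::condition_variable cv_;
    std::deque<T> items_;
  };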
@@ -48,8 +48,8 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
   StreamWithBuffer zc_write_strm(zerocopy_blob_ptr, true);
   zc_write_strm.Write(msg);
   int32_t nonempty_ndarray_count = zc_write_strm.buffer_list().size();
-  zerocopy_blob_ptr->append(reinterpret_cast<char *>(&nonempty_ndarray_count),
-                            sizeof(int32_t));
+  zerocopy_blob_ptr->append(
+      reinterpret_cast<char *>(&nonempty_ndarray_count), sizeof(int32_t));
   tp_msg.tensors.resize(nonempty_ndarray_count);
   // Hold the NDArray that ensure it's valid until write operation completes
   auto ndarray_holder = std::make_shared<std::vector<NDArray>>();

@@ -68,14 +68,14 @@ void TPSender::Send(const RPCMessage &msg, int recv_id) {
   }
   // Let's write blockingly in case of congestion in underlying transports.
   auto done = std::make_shared<std::promise<void>>();
-  pipe->write(tp_msg,
-              [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
+  pipe->write(
+      tp_msg, [ndarray_holder, recv_id, done](const tensorpipe::Error &error) {
        if (error) {
          LOG(FATAL) << "Failed to send message to " << recv_id
                     << ". Details: " << error.what();
        }
        done->set_value();
      });
   done->get_future().wait();
 }
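The hunk above turns tensorpipe's asynchronous write into a blocking call: the completion callback fulfills a shared std::promise and the caller waits on the matching future. The same idiom in isolation, using a stand-in async API rather than tensorpipe's:

  #include <functional>
  #include <future>
  #include <memory>
  #include <thread>

  // Stand-in for any callback-based asynchronous write API.
  void AsyncWrite(std::function<void(bool ok)> on_done) {
    std::thread([cb = std::move(on_done)] { cb(true); }).detach();
  }

  bool BlockingWrite() {
    // A shared promise lets the callback safely outlive this stack frame.
    auto done = std::make_shared<std::promise<bool>>();
    AsyncWrite([done](bool ok) { done->set_value(ok); });
    return done->get_future().get();  // block until the callback fires
  }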
@@ -120,7 +120,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
   if (error.isOfType<ListenerClosedError>()) {
     // Expected.
   } else {
-    LOG(WARNING) << "Unexpected error when accepting incoming pipe: " << error.what();
+    LOG(WARNING) << "Unexpected error when accepting incoming pipe: "
+                 << error.what();
   }
   return;
 }

@@ -133,7 +134,8 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
   // read the handshake message: "dglconnect"
   pipe->readDescriptor([pipe, this](const Error &error, Descriptor descriptor) {
     if (error) {
-      LOG(ERROR) << "Unexpected error when reading from accepted pipe: " << error.what();
+      LOG(ERROR) << "Unexpected error when reading from accepted pipe: "
+                 << error.what();
       return;
     }
     Allocation allocation;

@@ -145,10 +147,10 @@ void TPReceiver::OnAccepted(const Error &error, std::shared_ptr<Pipe> pipe) {
   });
 }
-void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
-                                 std::shared_ptr<RPCMessageQueue> queue) {
-  pipe->readDescriptor([pipe, queue = std::move(queue)](const Error &error,
-                                                        Descriptor descriptor) {
+void TPReceiver::ReceiveFromPipe(
+    std::shared_ptr<Pipe> pipe, std::shared_ptr<RPCMessageQueue> queue) {
+  pipe->readDescriptor([pipe, queue = std::move(queue)](
+                           const Error &error, Descriptor descriptor) {
     if (error) {
       // Error may happen when the pipe is closed
       return;

@@ -165,31 +167,33 @@ void TPReceiver::ReceiveFromPipe(std::shared_ptr<Pipe> pipe,
         allocation.tensors[i].buffer = cpu_buffer;
       }
     }
-    pipe->read(allocation, [allocation, descriptor = std::move(descriptor),
-                            queue = std::move(queue),
-                            pipe](const Error &error) {
-      if (error) {
-        // Because we always have a read event posted to the epoll,
-        // Therefore when pipe is closed, error will be raised.
-        // But this error is expected.
-        // Other error is not expected. But we cannot identify the error with
-        // each Other for now. Thus here we skip handling for all errors
-        return;
-      }
-      char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
-      std::vector<void *> buffer_list(descriptor.tensors.size());
-      for (size_t i = 0; i < descriptor.tensors.size(); i++) {
-        buffer_list[i] = allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
-      }
-      StreamWithBuffer zc_read_strm(
-          meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
-          buffer_list);
-      RPCMessage msg;
-      zc_read_strm.Read(&msg);
-      queue->push(msg);
-      TPReceiver::ReceiveFromPipe(pipe, queue);
-    });
+    pipe->read(
+        allocation, [allocation, descriptor = std::move(descriptor),
+                     queue = std::move(queue), pipe](const Error &error) {
+          if (error) {
+            // Because we always have a read event posted to the epoll,
+            // Therefore when pipe is closed, error will be raised.
+            // But this error is expected.
+            // Other error is not expected. But we cannot identify the error
+            // with each Other for now. Thus here we skip handling for all
+            // errors
+            return;
+          }
+
+          char *meta_msg_begin = const_cast<char *>(&descriptor.metadata[0]);
+          std::vector<void *> buffer_list(descriptor.tensors.size());
+          for (size_t i = 0; i < descriptor.tensors.size(); i++) {
+            buffer_list[i] =
+                allocation.tensors[i].buffer.unwrap<CpuBuffer>().ptr;
+          }
+          StreamWithBuffer zc_read_strm(
+              meta_msg_begin, descriptor.metadata.size() - sizeof(int32_t),
+              buffer_list);
+          RPCMessage msg;
+          zc_read_strm.Read(&msg);
+          queue->push(msg);
+          TPReceiver::ReceiveFromPipe(pipe, queue);
+        });
   });
 }
...

@@ -9,15 +9,16 @@
 #include <dmlc/logging.h>
 #include <tensorpipe/tensorpipe.h>
+#include <atomic>
 #include <deque>
 #include <memory>
 #include <string>
 #include <thread>
 #include <unordered_map>
 #include <vector>
-#include <atomic>
-#include "./queue.h"
 #include "../net_type.h"
+#include "./queue.h"
 namespace dgl {
 namespace rpc {

@@ -47,11 +48,12 @@ class TPSender : public RPCSender {
   /*!
    * \brief Connect to a receiver.
    *
-   * When there are multiple receivers to be connected, application will call `ConnectReceiver`
-   * for each and then call `ConnectReceiverFinalize` to make sure that either all the connections are
-   * successfully established or some of them fail.
-   *
+   * When there are multiple receivers to be connected, application will call
+   * `ConnectReceiver` for each and then call `ConnectReceiverFinalize` to make
+   * sure that either all the connections are successfully established or some
+   * of them fail.
+   *
    * \param addr Networking address, e.g., 'tcp://127.0.0.1:50091'
    * \param recv_id receiver's ID
    * \return True for success and False for fail
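The rewrapped comment documents a two-phase setup: connect every receiver, then finalize so that either all connections are established or the step fails as a whole. A hypothetical caller following that contract is sketched below; TPSender construction, the parameter types, and ConnectReceiverFinalize's exact signature are not shown in this diff and are assumed here:

  #include <string>
  #include <vector>

  bool ConnectAll(dgl::rpc::TPSender* sender,
                  const std::vector<std::string>& receiver_addrs) {
    for (int recv_id = 0; recv_id < static_cast<int>(receiver_addrs.size());
         ++recv_id) {
      // addr looks like "tcp://127.0.0.1:50091"
      if (!sender->ConnectReceiver(receiver_addrs[recv_id], recv_id)) {
        return false;
      }
    }
    return sender->ConnectReceiverFinalize();  // signature assumed
  }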
@@ -75,7 +77,7 @@ class TPSender : public RPCSender {
   /*!
    * \brief Communicator type: 'tp'
    */
-  const std::string &NetType() const override {
+  const std::string& NetType() const override {
     static const std::string net_type = "tensorpipe";
     return net_type;
   }

@@ -90,7 +92,7 @@ class TPSender : public RPCSender {
    * \brief pipe for each connection of receiver
    */
   std::unordered_map<int /* receiver ID */, std::shared_ptr<tensorpipe::Pipe>>
       pipes_;
   /*!
    * \brief receivers' listening address

@@ -129,13 +131,14 @@ class TPReceiver : public RPCReceiver {
    *
    * Wait() is not thread-safe and only one thread can invoke this API.
    */
-  bool Wait(const std::string &addr, int num_sender,
-            bool blocking = true) override;
+  bool Wait(
+      const std::string& addr, int num_sender, bool blocking = true) override;
   /*!
    * \brief Recv RPCMessage from Sender. Actually removing data from queue.
    * \param msg pointer of RPCmessage
-   * \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
+   * \param timeout The timeout value in milliseconds. If zero, wait
+   * indefinitely.
    * \return RPCStatus: kRPCSuccess or kRPCTimeOut.
    */
   RPCStatus Recv(RPCMessage* msg, int timeout) override;

@@ -150,7 +153,7 @@ class TPReceiver : public RPCReceiver {
   /*!
    * \brief Communicator type: 'tp' (tensorpipe)
    */
-  const std::string &NetType() const override {
+  const std::string& NetType() const override {
     static const std::string net_type = "tensorpipe";
     return net_type;
   }

@@ -158,8 +161,9 @@ class TPReceiver : public RPCReceiver {
   /*!
    * \brief Issue a receive request on pipe, and push the result into queue
    */
-  static void ReceiveFromPipe(std::shared_ptr<tensorpipe::Pipe> pipe,
-                              std::shared_ptr<RPCMessageQueue> queue);
+  static void ReceiveFromPipe(
+      std::shared_ptr<tensorpipe::Pipe> pipe,
+      std::shared_ptr<RPCMessageQueue> queue);
  private:
   /*!

@@ -186,9 +190,9 @@ class TPReceiver : public RPCReceiver {
   /*!
    * \brief pipe for each client connections
    */
-  std::unordered_map<int /* Sender (virutal) ID */,
-                     std::shared_ptr<tensorpipe::Pipe>>
-      pipes_;
+  std::unordered_map<
+      int /* Sender (virutal) ID */, std::shared_ptr<tensorpipe::Pipe>>
+      pipes_;
   /*!
    * \brief RPCMessage queue
...
@@ -3,16 +3,18 @@
  * Implementation of C API (reference: tvm/src/api/c_api.cc)
  * \file c_api.cc
  */
+#include <dgl/packed_func_ext.h>
+#include <dgl/runtime/c_object_api.h>
+#include <dgl/runtime/ndarray.h>
+#include <dgl/runtime/object.h>
 #include <dmlc/base.h>
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
-#include <dgl/runtime/c_object_api.h>
-#include <dgl/runtime/object.h>
-#include <dgl/runtime/ndarray.h>
-#include <dgl/packed_func_ext.h>
-#include <vector>
-#include <string>
 #include <exception>
+#include <string>
+#include <vector>
 #include "runtime_base.h"
 /*! \brief entry to to easily hold returning information */

@@ -20,7 +22,7 @@ struct DGLAPIThreadLocalEntry {
   /*! \brief result holder for returning strings */
   std::vector<std::string> ret_vec_str;
   /*! \brief result holder for returning string pointers */
-  std::vector<const char *> ret_vec_charp;
+  std::vector<const char*> ret_vec_charp;
   /*! \brief result holder for retruning string */
   std::string ret_str;
 };

@@ -44,7 +46,8 @@ struct APIAttrGetter : public AttrVisitor {
     if (skey == key) *ret = value[0];
   }
   void Visit(const char* key, uint64_t* value) final {
-    CHECK_LE(value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
+    CHECK_LE(
+        value[0], static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
         << "cannot return too big constant";
     if (skey == key) *ret = static_cast<int64_t>(value[0]);
   }

@@ -71,30 +74,16 @@ struct APIAttrGetter : public AttrVisitor {
 struct APIAttrDir : public AttrVisitor {
   std::vector<std::string>* names;
-  void Visit(const char* key, double* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, uint64_t* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, bool* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, int* value) final {
-    names->push_back(key);
-  }
+  void Visit(const char* key, double* value) final { names->push_back(key); }
+  void Visit(const char* key, int64_t* value) final { names->push_back(key); }
+  void Visit(const char* key, uint64_t* value) final { names->push_back(key); }
+  void Visit(const char* key, bool* value) final { names->push_back(key); }
+  void Visit(const char* key, int* value) final { names->push_back(key); }
   void Visit(const char* key, std::string* value) final {
     names->push_back(key);
   }
-  void Visit(const char* key, ObjectRef* value) final {
-    names->push_back(key);
-  }
-  void Visit(const char* key, NDArray* value) final {
-    names->push_back(key);
-  }
+  void Visit(const char* key, ObjectRef* value) final { names->push_back(key); }
+  void Visit(const char* key, NDArray* value) final { names->push_back(key); }
 };
 int DGLObjectFree(ObjectHandle handle) {

@@ -103,26 +92,22 @@ int DGLObjectFree(ObjectHandle handle) {
   API_END();
 }
-int DGLObjectTypeKey2Index(const char* type_key,
-                           int* out_index) {
+int DGLObjectTypeKey2Index(const char* type_key, int* out_index) {
   API_BEGIN();
   *out_index = static_cast<int>(Object::TypeKey2Index(type_key));
   API_END();
 }
-int DGLObjectGetTypeIndex(ObjectHandle handle,
-                          int* out_index) {
+int DGLObjectGetTypeIndex(ObjectHandle handle, int* out_index) {
   API_BEGIN();
-  *out_index = static_cast<int>(
-      (*static_cast<DGLAPIObject*>(handle))->type_index());
+  *out_index =
+      static_cast<int>((*static_cast<DGLAPIObject*>(handle))->type_index());
   API_END();
 }
-int DGLObjectGetAttr(ObjectHandle handle,
-                     const char* key,
-                     DGLValue* ret_val,
-                     int* ret_type_code,
-                     int* ret_success) {
+int DGLObjectGetAttr(
+    ObjectHandle handle, const char* key, DGLValue* ret_val, int* ret_type_code,
+    int* ret_success) {
   API_BEGIN();
   DGLRetValue rv;
   APIAttrGetter getter;

@@ -136,9 +121,8 @@ int DGLObjectGetAttr(ObjectHandle handle,
   } else {
     (*tobject)->VisitAttrs(&getter);
     *ret_success = getter.found_object_ref || rv.type_code() != kNull;
-    if (rv.type_code() == kStr ||
-        rv.type_code() == kDGLDataType) {
-      DGLAPIThreadLocalEntry *e = DGLAPIThreadLocalStore::Get();
+    if (rv.type_code() == kStr || rv.type_code() == kDGLDataType) {
+      DGLAPIThreadLocalEntry* e = DGLAPIThreadLocalStore::Get();
       e->ret_str = rv.operator std::string();
       *ret_type_code = kStr;
       ret_val->v_str = e->ret_str.c_str();

@@ -149,10 +133,9 @@ int DGLObjectGetAttr(ObjectHandle handle,
   API_END();
 }
-int DGLObjectListAttrNames(ObjectHandle handle,
-                           int *out_size,
-                           const char*** out_array) {
-  DGLAPIThreadLocalEntry *ret = DGLAPIThreadLocalStore::Get();
+int DGLObjectListAttrNames(
+    ObjectHandle handle, int* out_size, const char*** out_array) {
+  DGLAPIThreadLocalEntry* ret = DGLAPIThreadLocalStore::Get();
   API_BEGIN();
   ret->ret_vec_str.clear();
   DGLAPIObject* tobject = static_cast<DGLAPIObject*>(handle);
...
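The one-line Visit overloads above all implement the same visitor contract: an object enumerates its attributes by calling Visit(key, &field) once per field, and a visitor such as APIAttrDir simply records the keys. A minimal self-contained sketch of that contract (these types are illustrative, not DGL's actual classes):

  #include <cstdint>
  #include <string>
  #include <vector>

  struct AttrVisitorLike {
    virtual void Visit(const char* key, int64_t* value) = 0;
    virtual void Visit(const char* key, std::string* value) = 0;
    virtual ~AttrVisitorLike() = default;
  };

  struct NodeLike {
    int64_t num_nodes = 0;
    std::string name;
    // The object drives the visitor, once per attribute.
    void VisitAttrs(AttrVisitorLike* v) {
      v->Visit("num_nodes", &num_nodes);
      v->Visit("name", &name);
    }
  };

  // Counterpart of APIAttrDir above: collects attribute names only.
  struct NameCollector : AttrVisitorLike {
    std::vector<std::string> names;
    void Visit(const char* key, int64_t*) final { names.push_back(key); }
    void Visit(const char* key, std::string*) final { names.push_back(key); }
  };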
@@ -3,18 +3,20 @@
  * \file c_runtime_api.cc
  * \brief Runtime API implementation
  */
-#include <dmlc/thread_local.h>
-#include <dgl/runtime/c_runtime_api.h>
 #include <dgl/runtime/c_backend_api.h>
-#include <dgl/runtime/packed_func.h>
+#include <dgl/runtime/c_runtime_api.h>
+#include <dgl/runtime/device_api.h>
 #include <dgl/runtime/module.h>
+#include <dgl/runtime/packed_func.h>
 #include <dgl/runtime/registry.h>
-#include <dgl/runtime/device_api.h>
 #include <dgl/runtime/tensordispatch.h>
-#include <array>
+#include <dmlc/thread_local.h>
 #include <algorithm>
-#include <string>
+#include <array>
 #include <cstdlib>
+#include <string>
 #include "runtime_base.h"
 namespace dgl {

@@ -26,10 +28,14 @@ namespace runtime {
  */
 inline std::string DeviceName(int type) {
   switch (type) {
-    case kDGLCPU: return "cpu";
-    case kDGLCUDA: return "cuda";
+    case kDGLCPU:
+      return "cpu";
+    case kDGLCUDA:
+      return "cuda";
     // add more device here once supported
-    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
+    default:
+      LOG(FATAL) << "unknown type =" << type;
+      return "Unknown";
   }
 }

@@ -37,9 +43,7 @@ class DeviceAPIManager {
  public:
   static const int kMaxDeviceAPI = 32;
   // Get API
-  static DeviceAPI* Get(const DGLContext& ctx) {
-    return Get(ctx.device_type);
-  }
+  static DeviceAPI* Get(const DGLContext& ctx) { return Get(ctx.device_type); }
   static DeviceAPI* Get(int dev_type, bool allow_missing = false) {
     return Global()->GetAPI(dev_type, allow_missing);
   }

@@ -49,9 +53,7 @@ class DeviceAPIManager {
   DeviceAPI* rpc_api_{nullptr};
   std::mutex mutex_;
   // constructor
-  DeviceAPIManager() {
-    std::fill(api_.begin(), api_.end(), nullptr);
-  }
+  DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
   // Global static variable.
   static DeviceAPIManager* Global() {
     static DeviceAPIManager inst;

@@ -78,7 +80,8 @@ class DeviceAPIManager {
     auto* f = Registry::Get(factory);
     if (f == nullptr) {
       CHECK(allow_missing)
-          << "Device API " << name << " is not enabled. Please install the cuda version of dgl.";
+          << "Device API " << name
+          << " is not enabled. Please install the cuda version of dgl.";
       return nullptr;
     }
     void* ptr = (*f)();

@@ -95,9 +98,8 @@ DeviceAPI* DeviceAPI::Get(DGLDeviceType dev_type, bool allow_missing) {
   return DeviceAPIManager::Get(static_cast<int>(dev_type), allow_missing);
 }
-void* DeviceAPI::AllocWorkspace(DGLContext ctx,
-                                size_t size,
-                                DGLDataType type_hint) {
+void* DeviceAPI::AllocWorkspace(
+    DGLContext ctx, size_t size, DGLDataType type_hint) {
   return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
 }

@@ -114,9 +116,8 @@ void DeviceAPI::FreeStream(DGLContext ctx, DGLStreamHandle stream) {
   LOG(FATAL) << "Device does not support stream api.";
 }
-void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
-                                 DGLStreamHandle event_src,
-                                 DGLStreamHandle event_dst) {
+void DeviceAPI::SyncStreamFromTo(
+    DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
   LOG(FATAL) << "Device does not support stream api.";
 }

@@ -140,7 +141,7 @@ struct DGLRuntimeEntry {
 typedef dmlc::ThreadLocalStore<DGLRuntimeEntry> DGLAPIRuntimeStore;
-const char *DGLGetLastError() {
+const char* DGLGetLastError() {
   return DGLAPIRuntimeStore::Get()->last_error.c_str();
 }

@@ -152,30 +153,26 @@ void DGLAPISetLastError(const char* msg) {
 #endif
 }
-int DGLModLoadFromFile(const char* file_name,
-                       const char* format,
-                       DGLModuleHandle* out) {
+int DGLModLoadFromFile(
+    const char* file_name, const char* format, DGLModuleHandle* out) {
   API_BEGIN();
   Module m = Module::LoadFromFile(file_name, format);
   *out = new Module(m);
   API_END();
 }
-int DGLModImport(DGLModuleHandle mod,
-                 DGLModuleHandle dep) {
+int DGLModImport(DGLModuleHandle mod, DGLModuleHandle dep) {
   API_BEGIN();
-  static_cast<Module*>(mod)->Import(
-      *static_cast<Module*>(dep));
+  static_cast<Module*>(mod)->Import(*static_cast<Module*>(dep));
   API_END();
 }
-int DGLModGetFunction(DGLModuleHandle mod,
-                      const char* func_name,
-                      int query_imports,
-                      DGLFunctionHandle *func) {
+int DGLModGetFunction(
+    DGLModuleHandle mod, const char* func_name, int query_imports,
+    DGLFunctionHandle* func) {
   API_BEGIN();
-  PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
-      func_name, query_imports != 0);
+  PackedFunc pf =
+      static_cast<Module*>(mod)->GetFunction(func_name, query_imports != 0);
   if (pf != nullptr) {
     *func = new PackedFunc(pf);
   } else {

@@ -190,20 +187,18 @@ int DGLModFree(DGLModuleHandle mod) {
   API_END();
 }
-int DGLBackendGetFuncFromEnv(void* mod_node,
-                             const char* func_name,
-                             DGLFunctionHandle *func) {
+int DGLBackendGetFuncFromEnv(
+    void* mod_node, const char* func_name, DGLFunctionHandle* func) {
   API_BEGIN();
-  *func = (DGLFunctionHandle)(
-      static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
+  *func =
+      (DGLFunctionHandle)(static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(
+          func_name));
   API_END();
 }
-void* DGLBackendAllocWorkspace(int device_type,
-                               int device_id,
-                               uint64_t size,
-                               int dtype_code_hint,
-                               int dtype_bits_hint) {
+void* DGLBackendAllocWorkspace(
+    int device_type, int device_id, uint64_t size, int dtype_code_hint,
+    int dtype_bits_hint) {
   DGLContext ctx;
   ctx.device_type = static_cast<DGLDeviceType>(device_type);
   ctx.device_id = device_id;

@@ -213,14 +208,11 @@ void* DGLBackendAllocWorkspace(int device_type,
   type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
   type_hint.lanes = 1;
-  return DeviceAPIManager::Get(ctx)->AllocWorkspace(ctx,
-                                                    static_cast<size_t>(size),
-                                                    type_hint);
+  return DeviceAPIManager::Get(ctx)->AllocWorkspace(
+      ctx, static_cast<size_t>(size), type_hint);
 }
-int DGLBackendFreeWorkspace(int device_type,
-                            int device_id,
-                            void* ptr) {
+int DGLBackendFreeWorkspace(int device_type, int device_id, void* ptr) {
   DGLContext ctx;
   ctx.device_type = static_cast<DGLDeviceType>(device_type);
   ctx.device_id = device_id;

@@ -228,10 +220,7 @@ int DGLBackendFreeWorkspace(int device_type,
   return 0;
 }
-int DGLBackendRunOnce(void** handle,
-                      int (*f)(void*),
-                      void* cdata,
-                      int nbytes) {
+int DGLBackendRunOnce(void** handle, int (*f)(void*), void* cdata, int nbytes) {
   if (*handle == nullptr) {
     *handle = reinterpret_cast<void*>(1);
     return (*f)(cdata);

@@ -245,19 +234,15 @@ int DGLFuncFree(DGLFunctionHandle func) {
   API_END();
 }
-int DGLFuncCall(DGLFunctionHandle func,
-                DGLValue* args,
-                int* arg_type_codes,
-                int num_args,
-                DGLValue* ret_val,
-                int* ret_type_code) {
+int DGLFuncCall(
+    DGLFunctionHandle func, DGLValue* args, int* arg_type_codes, int num_args,
+    DGLValue* ret_val, int* ret_type_code) {
   API_BEGIN();
   DGLRetValue rv;
-  (*static_cast<const PackedFunc*>(func)).CallPacked(
-      DGLArgs(args, arg_type_codes, num_args), &rv);
+  (*static_cast<const PackedFunc*>(func))
+      .CallPacked(DGLArgs(args, arg_type_codes, num_args), &rv);
   // handle return string.
-  if (rv.type_code() == kStr ||
-      rv.type_code() == kDGLDataType ||
+  if (rv.type_code() == kStr || rv.type_code() == kDGLDataType ||
       rv.type_code() == kBytes) {
     DGLRuntimeEntry* e = DGLAPIRuntimeStore::Get();
     if (rv.type_code() != kDGLDataType) {

@@ -280,10 +265,8 @@ int DGLFuncCall(DGLFunctionHandle func,
   API_END();
 }
-int DGLCFuncSetReturn(DGLRetValueHandle ret,
-                      DGLValue* value,
-                      int* type_code,
-                      int num_ret) {
+int DGLCFuncSetReturn(
+    DGLRetValueHandle ret, DGLValue* value, int* type_code, int num_ret) {
   API_BEGIN();
   CHECK_EQ(num_ret, 1);
   DGLRetValue* rv = static_cast<DGLRetValue*>(ret);

@@ -291,16 +274,16 @@ int DGLCFuncSetReturn(DGLRetValueHandle ret,
   API_END();
 }
-int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
-                           void* resource_handle,
-                           DGLPackedCFuncFinalizer fin,
-                           DGLFunctionHandle *out) {
+int DGLFuncCreateFromCFunc(
+    DGLPackedCFunc func, void* resource_handle, DGLPackedCFuncFinalizer fin,
+    DGLFunctionHandle* out) {
   API_BEGIN();
   if (fin == nullptr) {
-    *out = new PackedFunc(
-        [func, resource_handle](DGLArgs args, DGLRetValue* rv) {
-          int ret = func((DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
-                         args.num_args, rv, resource_handle);
+    *out =
+        new PackedFunc([func, resource_handle](DGLArgs args, DGLRetValue* rv) {
+          int ret = func(
+              (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
+              args.num_args, rv, resource_handle);
           if (ret != 0) {
             std::string err = "DGLCall CFunc Error:\n";
             err += DGLGetLastError();

@@ -311,16 +294,16 @@ int DGLFuncCreateFromCFunc(DGLPackedCFunc func,
     // wrap it in a shared_ptr, with fin as deleter.
     // so fin will be called when the lambda went out of scope.
     std::shared_ptr<void> rpack(resource_handle, fin);
-    *out = new PackedFunc(
-        [func, rpack](DGLArgs args, DGLRetValue* rv) {
-          int ret = func((DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
-                         args.num_args, rv, rpack.get());
+    *out = new PackedFunc([func, rpack](DGLArgs args, DGLRetValue* rv) {
+      int ret = func(
+          (DGLValue*)args.values, (int*)args.type_codes,  // NOLINT(*)
+          args.num_args, rv, rpack.get());
       if (ret != 0) {
        std::string err = "DGLCall CFunc Error:\n";
        err += DGLGetLastError();
        throw dmlc::Error(err);
      }
    });
   }
   API_END();
 }

@@ -370,10 +353,8 @@ int DGLSynchronize(int device_type, int device_id, DGLStreamHandle stream) {
   API_END();
 }
-int DGLStreamStreamSynchronize(int device_type,
-                               int device_id,
-                               DGLStreamHandle src,
-                               DGLStreamHandle dst) {
+int DGLStreamStreamSynchronize(
+    int device_type, int device_id, DGLStreamHandle src, DGLStreamHandle dst) {
   API_BEGIN();
   DGLContext ctx;
   ctx.device_type = static_cast<DGLDeviceType>(device_type);

@@ -392,36 +373,35 @@ int DGLCbArgToReturn(DGLValue* value, int code) {
   API_END();
 }
-int DGLLoadTensorAdapter(const char *path) {
+int DGLLoadTensorAdapter(const char* path) {
   return TensorDispatcher::Global()->Load(path) ? 0 : -1;
 }
 // set device api
 DGL_REGISTER_GLOBAL(dgl::runtime::symbol::dgl_set_device)
-    .set_body([](DGLArgs args, DGLRetValue *ret) {
+    .set_body([](DGLArgs args, DGLRetValue* ret) {
       DGLContext ctx;
       ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
       ctx.device_id = args[1];
       DeviceAPIManager::Get(ctx)->SetDevice(ctx);
     });
 // set device api
 DGL_REGISTER_GLOBAL("_GetDeviceAttr")
-    .set_body([](DGLArgs args, DGLRetValue *ret) {
+    .set_body([](DGLArgs args, DGLRetValue* ret) {
       DGLContext ctx;
       ctx.device_type = static_cast<DGLDeviceType>(args[0].operator int());
       ctx.device_id = args[1];
       DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[2].operator int());
       if (kind == kExist) {
        DeviceAPI* api = DeviceAPIManager::Get(ctx.device_type, true);
        if (api != nullptr) {
          api->GetAttr(ctx, kind, ret);
        } else {
          *ret = 0;
        }
      } else {
        DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
      }
    });
@@ -4,32 +4,28 @@
  * \brief DGL runtime config
  */
-#include <dgl/runtime/registry.h>
 #include <dgl/runtime/config.h>
+#include <dgl/runtime/registry.h>
 using namespace dgl::runtime;
 namespace dgl {
 namespace runtime {
-void Config::EnableLibxsmm(bool b) {
-  libxsmm_ = b;
-}
+void Config::EnableLibxsmm(bool b) { libxsmm_ = b; }
-bool Config::IsLibxsmmAvailable() const {
-  return libxsmm_;
-}
+bool Config::IsLibxsmmAvailable() const { return libxsmm_; }
 DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigSetLibxsmm")
-    .set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
       bool use_libxsmm = args[0];
       dgl::runtime::Config::Global()->EnableLibxsmm(use_libxsmm);
     });
 DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigGetLibxsmm")
-    .set_body([] (DGLArgs args, DGLRetValue* rv) {
+    .set_body([](DGLArgs args, DGLRetValue* rv) {
       *rv = dgl::runtime::Config::Global()->IsLibxsmmAvailable();
     });
 }  // namespace runtime
 }  // namespace dgl
@@ -2,13 +2,15 @@
  * Copyright (c) 2016-2022 by Contributors
  * \file cpu_device_api.cc
  */
-#include <dmlc/logging.h>
-#include <dmlc/thread_local.h>
-#include <dgl/runtime/registry.h>
 #include <dgl/runtime/device_api.h>
+#include <dgl/runtime/registry.h>
 #include <dgl/runtime/tensordispatch.h>
+#include <dmlc/logging.h>
+#include <dmlc/thread_local.h>
 #include <cstdlib>
 #include <cstring>
 #include "workspace_pool.h"
 namespace dgl {

@@ -21,13 +23,11 @@ class CPUDeviceAPI final : public DeviceAPI {
       *rv = 1;
     }
   }
-  void* AllocDataSpace(DGLContext ctx,
-                       size_t nbytes,
-                       size_t alignment,
-                       DGLDataType type_hint) final {
+  void* AllocDataSpace(
+      DGLContext ctx, size_t nbytes, size_t alignment,
+      DGLDataType type_hint) final {
     TensorDispatcher* td = TensorDispatcher::Global();
-    if (td->IsAvailable())
-      return td->CPUAllocWorkspace(nbytes);
+    if (td->IsAvailable()) return td->CPUAllocWorkspace(nbytes);
     void* ptr;
 #if _MSC_VER || defined(__MINGW32__)

@@ -45,8 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI {
   void FreeDataSpace(DGLContext ctx, void* ptr) final {
     TensorDispatcher* td = TensorDispatcher::Global();
-    if (td->IsAvailable())
-      return td->CPUFreeWorkspace(ptr);
+    if (td->IsAvailable()) return td->CPUFreeWorkspace(ptr);
 #if _MSC_VER || defined(__MINGW32__)
     _aligned_free(ptr);

@@ -55,25 +54,21 @@ class CPUDeviceAPI final : public DeviceAPI {
 #endif
   }
-  void CopyDataFromTo(const void* from,
-                      size_t from_offset,
-                      void* to,
-                      size_t to_offset,
-                      size_t size,
-                      DGLContext ctx_from,
-                      DGLContext ctx_to,
-                      DGLDataType type_hint) final {
-    memcpy(static_cast<char*>(to) + to_offset,
-           static_cast<const char*>(from) + from_offset,
-           size);
+  void CopyDataFromTo(
+      const void* from, size_t from_offset, void* to, size_t to_offset,
+      size_t size, DGLContext ctx_from, DGLContext ctx_to,
+      DGLDataType type_hint) final {
+    memcpy(
+        static_cast<char*>(to) + to_offset,
+        static_cast<const char*>(from) + from_offset, size);
   }
   DGLStreamHandle CreateStream(DGLContext) final { return nullptr; }
-  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {
-  }
+  void StreamSync(DGLContext ctx, DGLStreamHandle stream) final {}
-  void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final;
+  void* AllocWorkspace(
+      DGLContext ctx, size_t size, DGLDataType type_hint) final;
   void FreeWorkspace(DGLContext ctx, void* data) final;
   static const std::shared_ptr<CPUDeviceAPI>& Global() {

@@ -84,32 +79,29 @@ class CPUDeviceAPI final : public DeviceAPI {
 };
 struct CPUWorkspacePool : public WorkspacePool {
-  CPUWorkspacePool() :
-      WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
+  CPUWorkspacePool() : WorkspacePool(kDGLCPU, CPUDeviceAPI::Global()) {}
 };
-void* CPUDeviceAPI::AllocWorkspace(DGLContext ctx,
-                                   size_t size,
-                                   DGLDataType type_hint) {
+void* CPUDeviceAPI::AllocWorkspace(
+    DGLContext ctx, size_t size, DGLDataType type_hint) {
   TensorDispatcher* td = TensorDispatcher::Global();
-  if (td->IsAvailable())
-    return td->CPUAllocWorkspace(size);
+  if (td->IsAvailable()) return td->CPUAllocWorkspace(size);
-  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(ctx, size);
+  return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->AllocWorkspace(
+      ctx, size);
 }
 void CPUDeviceAPI::FreeWorkspace(DGLContext ctx, void* data) {
   TensorDispatcher* td = TensorDispatcher::Global();
-  if (td->IsAvailable())
-    return td->CPUFreeWorkspace(data);
+  if (td->IsAvailable()) return td->CPUFreeWorkspace(data);
   dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()->FreeWorkspace(ctx, data);
 }
 DGL_REGISTER_GLOBAL("device_api.cpu")
     .set_body([](DGLArgs args, DGLRetValue* rv) {
       DeviceAPI* ptr = CPUDeviceAPI::Global().get();
       *rv = static_cast<void*>(ptr);
     });
 }  // namespace runtime
 }  // namespace dgl

@@ -7,11 +7,13 @@
 #define DGL_RUNTIME_CUDA_CUDA_COMMON_H_
 #include <cublas_v2.h>
-#include <cusparse.h>
 #include <cuda_runtime.h>
 #include <curand.h>
+#include <cusparse.h>
 #include <dgl/runtime/packed_func.h>
 #include <string>
 #include "../workspace_pool.h"
 namespace dgl {

@@ -19,94 +21,89 @@ namespace runtime {
 template <typename T>
 inline bool is_zero(T size) {
   return size == 0;
 }
 template <>
 inline bool is_zero<dim3>(dim3 size) {
   return size.x == 0 || size.y == 0 || size.z == 0;
 }
 #define CUDA_DRIVER_CALL(x) \
   { \
     CUresult result = x; \
     if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { \
-      const char *msg; \
+      const char* msg; \
       cuGetErrorName(result, &msg); \
-      LOG(FATAL) \
-          << "CUDAError: " #x " failed with error: " << msg; \
+      LOG(FATAL) << "CUDAError: " #x " failed with error: " << msg; \
     } \
   }
 #define CUDA_CALL(func) \
   { \
     cudaError_t e = (func); \
     CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
         << "CUDA: " << cudaGetErrorString(e); \
   }
 #define CUDA_KERNEL_CALL(kernel, nblks, nthrs, shmem, stream, ...) \
   { \
-    if (!dgl::runtime::is_zero((nblks)) && \
-        !dgl::runtime::is_zero((nthrs))) { \
-      (kernel) <<< (nblks), (nthrs), (shmem), (stream) >>> \
-          (__VA_ARGS__); \
-      cudaError_t e = cudaGetLastError(); \
-      CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
-          << "CUDA kernel launch error: " \
-          << cudaGetErrorString(e); \
-    } \
+    if (!dgl::runtime::is_zero((nblks)) && !dgl::runtime::is_zero((nthrs))) { \
+      (kernel)<<<(nblks), (nthrs), (shmem), (stream)>>>(__VA_ARGS__); \
+      cudaError_t e = cudaGetLastError(); \
+      CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
+          << "CUDA kernel launch error: " << cudaGetErrorString(e); \
+    } \
   }
 #define CUSPARSE_CALL(func) \
   { \
     cusparseStatus_t e = (func); \
-    CHECK(e == CUSPARSE_STATUS_SUCCESS) \
-        << "CUSPARSE ERROR: " << e; \
+    CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR: " << e; \
   }
 #define CUBLAS_CALL(func) \
   { \
     cublasStatus_t e = (func); \
     CHECK(e == CUBLAS_STATUS_SUCCESS) << "CUBLAS ERROR: " << e; \
   }
 #define CURAND_CALL(func) \
   { \
     curandStatus_t e = (func); \
     CHECK(e == CURAND_STATUS_SUCCESS) \
-        << "CURAND Error: " << dgl::runtime::curandGetErrorString(e) \
-        << " at " << __FILE__ << ":" << __LINE__; \
+        << "CURAND Error: " << dgl::runtime::curandGetErrorString(e) << " at " \
+        << __FILE__ << ":" << __LINE__; \
   }
 inline const char* curandGetErrorString(curandStatus_t error) {
   switch (error) {
     case CURAND_STATUS_SUCCESS:
       return "CURAND_STATUS_SUCCESS";
     case CURAND_STATUS_VERSION_MISMATCH:
       return "CURAND_STATUS_VERSION_MISMATCH";
     case CURAND_STATUS_NOT_INITIALIZED:
       return "CURAND_STATUS_NOT_INITIALIZED";
     case CURAND_STATUS_ALLOCATION_FAILED:
       return "CURAND_STATUS_ALLOCATION_FAILED";
     case CURAND_STATUS_TYPE_ERROR:
       return "CURAND_STATUS_TYPE_ERROR";
     case CURAND_STATUS_OUT_OF_RANGE:
       return "CURAND_STATUS_OUT_OF_RANGE";
     case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
       return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
     case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
       return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
     case CURAND_STATUS_LAUNCH_FAILURE:
       return "CURAND_STATUS_LAUNCH_FAILURE";
     case CURAND_STATUS_PREEXISTING_FAILURE:
       return "CURAND_STATUS_PREEXISTING_FAILURE";
     case CURAND_STATUS_INITIALIZATION_FAILED:
       return "CURAND_STATUS_INITIALIZATION_FAILED";
     case CURAND_STATUS_ARCH_MISMATCH:
       return "CURAND_STATUS_ARCH_MISMATCH";
     case CURAND_STATUS_INTERNAL_ERROR:
       return "CURAND_STATUS_INTERNAL_ERROR";
   }
   // To suppress compiler warning.
   return "Unrecognized curand error string";
...
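A short sketch of how a checked macro such as CUDA_CALL above is typically used; the helper function is illustrative and assumes the macro definitions from this header are in scope:

  #include <cstddef>
  #include <cuda_runtime.h>

  // Every CUDA runtime call is wrapped so a failure aborts through CHECK
  // with the stringified call and the cudaGetErrorString() message.
  void* AllocOnDevice(int device_id, size_t nbytes) {
    CUDA_CALL(cudaSetDevice(device_id));
    void* ptr = nullptr;
    CUDA_CALL(cudaMalloc(&ptr, nbytes));
    return ptr;
  }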
...@@ -3,11 +3,12 @@ ...@@ -3,11 +3,12 @@
* \file cuda_device_api.cc * \file cuda_device_api.cc
* \brief GPU specific API * \brief GPU specific API
*/ */
#include <cuda_runtime.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/tensordispatch.h> #include <dgl/runtime/tensordispatch.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <dgl/runtime/registry.h>
#include <cuda_runtime.h>
#include "cuda_common.h" #include "cuda_common.h"
namespace dgl { namespace dgl {
...@@ -28,9 +29,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -28,9 +29,7 @@ class CUDADeviceAPI final : public DeviceAPI {
is_available_ = count > 0; is_available_ = count > 0;
} }
bool IsAvailable() final { bool IsAvailable() final { return is_available_; }
return is_available_;
}
void SetDevice(DGLContext ctx) final { void SetDevice(DGLContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
...@@ -39,10 +38,10 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -39,10 +38,10 @@ class CUDADeviceAPI final : public DeviceAPI {
int value = 0; int value = 0;
switch (kind) { switch (kind) {
case kExist: case kExist:
value = ( value =
cudaDeviceGetAttribute( (cudaDeviceGetAttribute(
&value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) &value, cudaDevAttrMaxThreadsPerBlock, ctx.device_id) ==
== cudaSuccess); cudaSuccess);
break; break;
case kMaxThreadsPerBlock: { case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(cudaDeviceGetAttribute(
...@@ -50,8 +49,8 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -50,8 +49,8 @@ class CUDADeviceAPI final : public DeviceAPI {
break; break;
} }
case kWarpSize: { case kWarpSize: {
CUDA_CALL(cudaDeviceGetAttribute( CUDA_CALL(
&value, cudaDevAttrWarpSize, ctx.device_id)); cudaDeviceGetAttribute(&value, cudaDevAttrWarpSize, ctx.device_id));
break; break;
} }
case kMaxSharedMemoryPerBlock: { case kMaxSharedMemoryPerBlock: {
...@@ -96,26 +95,24 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -96,26 +95,24 @@ class CUDADeviceAPI final : public DeviceAPI {
&dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id)); &dims[2], cudaDevAttrMaxBlockDimZ, ctx.device_id));
std::stringstream ss; // use json string to return multiple int values; std::stringstream ss; // use json string to return multiple int values;
ss << "[" << dims[0] <<", " << dims[1] << ", " << dims[2] << "]"; ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str(); *rv = ss.str();
return; return;
} }
} }
*rv = value; *rv = value;
} }
void* AllocDataSpace(DGLContext ctx, void* AllocDataSpace(
size_t nbytes, DGLContext ctx, size_t nbytes, size_t alignment,
size_t alignment, DGLDataType type_hint) final {
DGLDataType type_hint) final {
SetDevice(ctx); SetDevice(ctx);
// Redirect to PyTorch's allocator when available. // Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable())
return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream()); return td->CUDAAllocWorkspace(nbytes, getCurrentCUDAStream());
CHECK_EQ(256 % alignment, 0U) CHECK_EQ(256 % alignment, 0U) << "CUDA space is aligned at 256 bytes";
<< "CUDA space is aligned at 256 bytes"; void* ret;
void *ret;
CUDA_CALL(cudaMalloc(&ret, nbytes)); CUDA_CALL(cudaMalloc(&ret, nbytes));
return ret; return ret;
} }
...@@ -123,21 +120,15 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -123,21 +120,15 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeDataSpace(DGLContext ctx, void* ptr) final { void FreeDataSpace(DGLContext ctx, void* ptr) final {
SetDevice(ctx); SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CUDAFreeWorkspace(ptr);
return td->CUDAFreeWorkspace(ptr);
CUDA_CALL(cudaFree(ptr)); CUDA_CALL(cudaFree(ptr));
} }
void CopyDataFromTo(const void* from, void CopyDataFromTo(
size_t from_offset, const void* from, size_t from_offset, void* to, size_t to_offset,
void* to, size_t size, DGLContext ctx_from, DGLContext ctx_to,
size_t to_offset, DGLDataType type_hint, DGLStreamHandle stream) {
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint,
DGLStreamHandle stream) {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream); cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
to = static_cast<char*>(to) + to_offset; to = static_cast<char*>(to) + to_offset;
...@@ -146,14 +137,15 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -146,14 +137,15 @@ class CUDADeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) { if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream); GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else { } else {
CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.device_id, CUDA_CALL(cudaMemcpyPeerAsync(
from, ctx_from.device_id, to, ctx_to.device_id, from, ctx_from.device_id, size, cu_stream));
size, cu_stream));
} }
} else if (ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) { } else if (
ctx_from.device_type == kDGLCUDA && ctx_to.device_type == kDGLCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id)); CUDA_CALL(cudaSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream); GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) { } else if (
ctx_from.device_type == kDGLCPU && ctx_to.device_type == kDGLCUDA) {
CUDA_CALL(cudaSetDevice(ctx_to.device_id)); CUDA_CALL(cudaSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream); GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else { } else {
...@@ -161,16 +153,14 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -161,16 +153,14 @@ class CUDADeviceAPI final : public DeviceAPI {
} }
} }
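The dispatch above derives the cudaMemcpyKind from the (from, to) device types and only needs cudaMemcpyPeerAsync when crossing two GPUs. A hedged sketch of a host-to-device copy exercising the kDGLCPU to kDGLCUDA branch, reusing the hypothetical dev, ctx, and buf from the previous sketch and calling the stream-less overload that forwards the current stream:

```cpp
#include <vector>

std::vector<float> host(256, 1.0f);
dev->CopyDataFromTo(
    host.data(), /*from_offset=*/0, buf, /*to_offset=*/0,
    host.size() * sizeof(float), DGLContext{kDGLCPU, 0}, ctx, DGLDataType{});
```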
void CopyDataFromTo(const void* from, void CopyDataFromTo(
size_t from_offset, const void* from, size_t from_offset, void* to, size_t to_offset,
void* to, size_t size, DGLContext ctx_from, DGLContext ctx_to,
size_t to_offset, DGLDataType type_hint) final {
size_t size,
DGLContext ctx_from,
DGLContext ctx_to,
DGLDataType type_hint) final {
auto stream = GetStream(); auto stream = GetStream();
CopyDataFromTo(from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint, stream); CopyDataFromTo(
from, from_offset, to, to_offset, size, ctx_from, ctx_to, type_hint,
stream);
} }
DGLStreamHandle CreateStream(DGLContext ctx) { DGLStreamHandle CreateStream(DGLContext ctx) {
...@@ -187,7 +177,8 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -187,7 +177,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDA_CALL(cudaStreamDestroy(cu_stream)); CUDA_CALL(cudaStreamDestroy(cu_stream));
} }
void SyncStreamFromTo(DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) { void SyncStreamFromTo(
DGLContext ctx, DGLStreamHandle event_src, DGLStreamHandle event_dst) {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaStream_t src_stream = static_cast<cudaStream_t>(event_src); cudaStream_t src_stream = static_cast<cudaStream_t>(event_src);
cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst); cudaStream_t dst_stream = static_cast<cudaStream_t>(event_dst);
...@@ -222,54 +213,54 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -222,54 +213,54 @@ class CUDADeviceAPI final : public DeviceAPI {
*/ */
void PinData(void* ptr, size_t nbytes) { void PinData(void* ptr, size_t nbytes) {
// prevent users from pinning empty tensors or graphs // prevent users from pinning empty tensors or graphs
if (ptr == nullptr || nbytes == 0) if (ptr == nullptr || nbytes == 0) return;
return;
CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault)); CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
} }
void UnpinData(void* ptr) { void UnpinData(void* ptr) {
if (ptr == nullptr) if (ptr == nullptr) return;
return;
CUDA_CALL(cudaHostUnregister(ptr)); CUDA_CALL(cudaHostUnregister(ptr));
} }
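PinData/UnpinData page-lock memory that was allocated elsewhere (e.g., a NumPy buffer), which is what makes later cudaMemcpyAsync calls on that memory genuinely asynchronous. A minimal standalone sketch of the same two CUDA calls:

```cpp
#include <vector>

// Sketch: page-lock an existing host buffer, use it for async transfers,
// then release the registration before the buffer is destroyed.
std::vector<float> staging(1 << 20);
CUDA_CALL(cudaHostRegister(
    staging.data(), staging.size() * sizeof(float), cudaHostRegisterDefault));
// ... cudaMemcpyAsync to/from staging.data() can now overlap with compute ...
CUDA_CALL(cudaHostUnregister(staging.data()));
```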
bool IsPinned(const void* ptr) override { bool IsPinned(const void* ptr) override {
// can't be a pinned tensor if CUDA context is unavailable. // can't be a pinned tensor if CUDA context is unavailable.
if (!is_available_) if (!is_available_) return false;
return false;
cudaPointerAttributes attr; cudaPointerAttributes attr;
cudaError_t status = cudaPointerGetAttributes(&attr, ptr); cudaError_t status = cudaPointerGetAttributes(&attr, ptr);
bool result = false; bool result = false;
switch (status) { switch (status) {
case cudaErrorInvalidValue: case cudaErrorInvalidValue:
// might be a normal CPU tensor in CUDA 10.2- // might be a normal CPU tensor in CUDA 10.2-
cudaGetLastError(); // clear error cudaGetLastError(); // clear error
break; break;
case cudaSuccess: case cudaSuccess:
result = (attr.type == cudaMemoryTypeHost); result = (attr.type == cudaMemoryTypeHost);
break; break;
case cudaErrorInitializationError: case cudaErrorInitializationError:
case cudaErrorNoDevice: case cudaErrorNoDevice:
case cudaErrorInsufficientDriver: case cudaErrorInsufficientDriver:
case cudaErrorInvalidDevice: case cudaErrorInvalidDevice:
// We don't want to fail in these particular cases since this function can be called // We don't want to fail in these particular cases since this function
// when users only want to run on CPU even if CUDA API is enabled, or in a forked // can be called when users only want to run on CPU even if CUDA API is
// subprocess where CUDA context cannot be initialized. So we just mark the CUDA // enabled, or in a forked subprocess where CUDA context cannot be
// context to unavailable and return. // initialized. So we just mark the CUDA context to unavailable and
is_available_ = false; // return.
cudaGetLastError(); // clear error is_available_ = false;
break; cudaGetLastError(); // clear error
default: break;
LOG(FATAL) << "error while determining memory status: " << cudaGetErrorString(status); default:
break; LOG(FATAL) << "error while determining memory status: "
<< cudaGetErrorString(status);
break;
} }
return result; return result;
} }
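The error handling above exists because cudaPointerGetAttributes is the only available probe: on CUDA 11+ a plain malloc'd pointer returns cudaSuccess with type cudaMemoryTypeUnregistered, CUDA 10.2 and earlier return cudaErrorInvalidValue instead, and a forked subprocess without a CUDA context fails with initialization errors. A condensed sketch of the probe itself (helper name hypothetical):

```cpp
bool ProbePinned(const void* ptr) {
  cudaPointerAttributes attr;
  if (cudaPointerGetAttributes(&attr, ptr) != cudaSuccess) {
    cudaGetLastError();  // clear the sticky error so later calls succeed
    return false;        // unknown -> treat as not pinned
  }
  return attr.type == cudaMemoryTypeHost;
}
```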
void* AllocWorkspace(DGLContext ctx, size_t size, DGLDataType type_hint) final { void* AllocWorkspace(
DGLContext ctx, size_t size, DGLDataType type_hint) final {
SetDevice(ctx); SetDevice(ctx);
// Redirect to PyTorch's allocator when available. // Redirect to PyTorch's allocator when available.
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
...@@ -282,8 +273,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -282,8 +273,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void FreeWorkspace(DGLContext ctx, void* data) final { void FreeWorkspace(DGLContext ctx, void* data) final {
SetDevice(ctx); SetDevice(ctx);
TensorDispatcher* td = TensorDispatcher::Global(); TensorDispatcher* td = TensorDispatcher::Global();
if (td->IsAvailable()) if (td->IsAvailable()) return td->CUDAFreeWorkspace(data);
return td->CUDAFreeWorkspace(data);
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
} }
...@@ -295,14 +285,13 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -295,14 +285,13 @@ class CUDADeviceAPI final : public DeviceAPI {
} }
private: private:
static void GPUCopy(const void* from, static void GPUCopy(
void* to, const void* from, void* to, size_t size, cudaMemcpyKind kind,
size_t size, cudaStream_t stream) {
cudaMemcpyKind kind,
cudaStream_t stream) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream)); CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
if (stream == 0 && kind == cudaMemcpyDeviceToHost) { if (stream == 0 && kind == cudaMemcpyDeviceToHost) {
// only wait for the copy, when it's on the default stream, and it's to host memory // only wait for the copy, when it's on the default stream, and it's to
// host memory
CUDA_CALL(cudaStreamSynchronize(stream)); CUDA_CALL(cudaStreamSynchronize(stream));
} }
} }
...@@ -312,9 +301,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -312,9 +301,7 @@ class CUDADeviceAPI final : public DeviceAPI {
typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore; typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
CUDAThreadEntry::CUDAThreadEntry() CUDAThreadEntry::CUDAThreadEntry() : pool(kDGLCUDA, CUDADeviceAPI::Global()) {}
: pool(kDGLCUDA, CUDADeviceAPI::Global()) {
}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
return CUDAThreadStore::Get(); return CUDAThreadStore::Get();
...@@ -329,10 +316,10 @@ cudaStream_t getCurrentCUDAStream() { ...@@ -329,10 +316,10 @@ cudaStream_t getCurrentCUDAStream() {
} }
DGL_REGISTER_GLOBAL("device_api.cuda") DGL_REGISTER_GLOBAL("device_api.cuda")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get(); DeviceAPI* ptr = CUDADeviceAPI::Global().get();
*rv = static_cast<void*>(ptr); *rv = static_cast<void*>(ptr);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
...@@ -9,14 +9,14 @@ ...@@ -9,14 +9,14 @@
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include "cuda_runtime.h"
#include "cuda_common.h" #include "cuda_common.h"
#include "cuda_runtime.h"
namespace dgl { namespace dgl {
namespace runtime { namespace runtime {
namespace cuda { namespace cuda {
template<typename> template <typename>
class OrderedHashTable; class OrderedHashTable;
/*! /*!
...@@ -31,7 +31,7 @@ class OrderedHashTable; ...@@ -31,7 +31,7 @@ class OrderedHashTable;
* used. * used.
* *
* The hash table should be used in two phases, with the first being populating * The hash table should be used in two phases, with the first being populating
* the hash table with the OrderedHashTable object, and then generating this * the hash table with the OrderedHashTable object, and then generating this
* handle from it. This object can then be used to search the hash table, * handle from it. This object can then be used to search the hash table,
* to find mappings, from within CUDA code. * to find mappings, from within CUDA code.
* *
...@@ -62,7 +62,7 @@ class OrderedHashTable; ...@@ -62,7 +62,7 @@ class OrderedHashTable;
* *
* \tparam IdType The type of the IDs. * \tparam IdType The type of the IDs.
*/ */
template<typename IdType> template <typename IdType>
class DeviceOrderedHashTable { class DeviceOrderedHashTable {
public: public:
/** /**
...@@ -80,16 +80,15 @@ class DeviceOrderedHashTable { ...@@ -80,16 +80,15 @@ class DeviceOrderedHashTable {
/** /**
* \brief The index of the item when inserted into the hashtable (e.g., * \brief The index of the item when inserted into the hashtable (e.g.,
* the index within the array passed into FillWithDuplicates()). * the index within the array passed into FillWithDuplicates()).
*/ */
int64_t index; int64_t index;
}; };
typedef const Mapping* ConstIterator; typedef const Mapping* ConstIterator;
DeviceOrderedHashTable( DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
const DeviceOrderedHashTable& other) = default; DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) =
DeviceOrderedHashTable& operator=( default;
const DeviceOrderedHashTable& other) = default;
/** /**
* \brief Find the non-mutable mapping of a given key within the hash table. * \brief Find the non-mutable mapping of a given key within the hash table.
...@@ -101,8 +100,7 @@ class DeviceOrderedHashTable { ...@@ -101,8 +100,7 @@ class DeviceOrderedHashTable {
* *
* \return An iterator to the mapping. * \return An iterator to the mapping.
*/ */
inline __device__ ConstIterator Search( inline __device__ ConstIterator Search(const IdType id) const {
const IdType id) const {
const IdType pos = SearchForPosition(id); const IdType pos = SearchForPosition(id);
return &table_[pos]; return &table_[pos];
...@@ -115,8 +113,7 @@ class DeviceOrderedHashTable { ...@@ -115,8 +113,7 @@ class DeviceOrderedHashTable {
* *
* \return True if the key exists in the hashtable. * \return True if the key exists in the hashtable.
*/ */
inline __device__ bool Contains( inline __device__ bool Contains(const IdType id) const {
const IdType id) const {
IdType pos = Hash(id); IdType pos = Hash(id);
IdType delta = 1; IdType delta = 1;
...@@ -124,8 +121,8 @@ class DeviceOrderedHashTable { ...@@ -124,8 +121,8 @@ class DeviceOrderedHashTable {
if (table_[pos].key == id) { if (table_[pos].key == id) {
return true; return true;
} }
pos = Hash(pos+delta); pos = Hash(pos + delta);
delta +=1; delta += 1;
} }
return false; return false;
} }
...@@ -134,7 +131,7 @@ class DeviceOrderedHashTable { ...@@ -134,7 +131,7 @@ class DeviceOrderedHashTable {
// Must be uniform bytes for memset to work // Must be uniform bytes for memset to work
static constexpr IdType kEmptyKey = static_cast<IdType>(-1); static constexpr IdType kEmptyKey = static_cast<IdType>(-1);
const Mapping * table_; const Mapping* table_;
size_t size_; size_t size_;
/** /**
...@@ -143,9 +140,7 @@ class DeviceOrderedHashTable { ...@@ -143,9 +140,7 @@ class DeviceOrderedHashTable {
* \param table The table stored in GPU memory. * \param table The table stored in GPU memory.
* \param size The size of the table. * \param size The size of the table.
*/ */
explicit DeviceOrderedHashTable( explicit DeviceOrderedHashTable(const Mapping* table, size_t size);
const Mapping * table,
size_t size);
/** /**
* \brief Search for an item in the hash table which is known to exist. * \brief Search for an item in the hash table which is known to exist.
...@@ -157,16 +152,15 @@ class DeviceOrderedHashTable { ...@@ -157,16 +152,15 @@ class DeviceOrderedHashTable {
* *
* \return The position of the item in the hashtable. * \return The position of the item in the hashtable.
*/ */
inline __device__ IdType SearchForPosition( inline __device__ IdType SearchForPosition(const IdType id) const {
const IdType id) const {
IdType pos = Hash(id); IdType pos = Hash(id);
// linearly scan for matching entry // linearly scan for matching entry
IdType delta = 1; IdType delta = 1;
while (table_[pos].key != id) { while (table_[pos].key != id) {
assert(table_[pos].key != kEmptyKey); assert(table_[pos].key != kEmptyKey);
pos = Hash(pos+delta); pos = Hash(pos + delta);
delta +=1; delta += 1;
} }
assert(pos < size_); assert(pos < size_);
...@@ -180,10 +174,7 @@ class DeviceOrderedHashTable { ...@@ -180,10 +174,7 @@ class DeviceOrderedHashTable {
* *
* \return The hash. * \return The hash.
*/ */
inline __device__ size_t Hash( inline __device__ size_t Hash(const IdType id) const { return id % size_; }
const IdType id) const {
return id % size_;
}
friend class OrderedHashTable<IdType>; friend class OrderedHashTable<IdType>;
}; };
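Because DeviceOrderedHashTable is a lightweight POD-style view, it is passed to kernels by value. A hedged kernel sketch using Search() to remap IDs; note that only the `index` member of Mapping is visible in this hunk, so the `local` field used below is an assumption based on the full header:

```cpp
// Hypothetical kernel: translate global IDs into compacted local IDs.
// Assumes Mapping also carries a `local` slot (not shown in this diff).
template <typename IdType>
__global__ void _RemapKernel(
    dgl::runtime::cuda::DeviceOrderedHashTable<IdType> table,
    const IdType* global_ids, IdType* local_ids, size_t num_ids) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_ids) {
    local_ids[i] = table.Search(global_ids[i])->local;  // key must exist
  }
}
```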
...@@ -219,7 +210,7 @@ class DeviceOrderedHashTable { ...@@ -219,7 +210,7 @@ class DeviceOrderedHashTable {
* *
* \tparam IdType The type of the IDs. * \tparam IdType The type of the IDs.
*/ */
template<typename IdType> template <typename IdType>
class OrderedHashTable { class OrderedHashTable {
public: public:
static constexpr int kDefaultScale = 3; static constexpr int kDefaultScale = 3;
...@@ -237,9 +228,7 @@ class OrderedHashTable { ...@@ -237,9 +228,7 @@ class OrderedHashTable {
* \param stream The stream to use for initializing the hashtable. * \param stream The stream to use for initializing the hashtable.
*/ */
OrderedHashTable( OrderedHashTable(
const size_t size, const size_t size, DGLContext ctx, cudaStream_t stream,
DGLContext ctx,
cudaStream_t stream,
const int scale = kDefaultScale); const int scale = kDefaultScale);
/** /**
...@@ -248,10 +237,8 @@ class OrderedHashTable { ...@@ -248,10 +237,8 @@ class OrderedHashTable {
~OrderedHashTable(); ~OrderedHashTable();
// Disable copying // Disable copying
OrderedHashTable( OrderedHashTable(const OrderedHashTable& other) = delete;
const OrderedHashTable& other) = delete; OrderedHashTable& operator=(const OrderedHashTable& other) = delete;
OrderedHashTable& operator=(
const OrderedHashTable& other) = delete;
/** /**
* \brief Fill the hashtable with the array containing possibly duplicate * \brief Fill the hashtable with the array containing possibly duplicate
...@@ -264,11 +251,8 @@ class OrderedHashTable { ...@@ -264,11 +251,8 @@ class OrderedHashTable {
* \param stream The stream to perform operations on. * \param stream The stream to perform operations on.
*/ */
void FillWithDuplicates( void FillWithDuplicates(
const IdType * const input, const IdType* const input, const size_t num_input, IdType* const unique,
const size_t num_input, int64_t* const num_unique, cudaStream_t stream);
IdType * const unique,
int64_t * const num_unique,
cudaStream_t stream);
/** /**
* \brief Fill the hashtable with an array of unique keys. * \brief Fill the hashtable with an array of unique keys.
...@@ -278,9 +262,7 @@ class OrderedHashTable { ...@@ -278,9 +262,7 @@ class OrderedHashTable {
* \param stream The stream to perform operations on. * \param stream The stream to perform operations on.
*/ */
void FillWithUnique( void FillWithUnique(
const IdType * const input, const IdType* const input, const size_t num_input, cudaStream_t stream);
const size_t num_input,
cudaStream_t stream);
/** /**
* \brief Get a version of the hashtable usable from device functions. * \brief Get a version of the hashtable usable from device functions.
...@@ -290,12 +272,11 @@ class OrderedHashTable { ...@@ -290,12 +272,11 @@ class OrderedHashTable {
DeviceOrderedHashTable<IdType> DeviceHandle() const; DeviceOrderedHashTable<IdType> DeviceHandle() const;
private: private:
Mapping * table_; Mapping* table_;
size_t size_; size_t size_;
DGLContext ctx_; DGLContext ctx_;
}; };
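The two-phase protocol from the class comment, as a hedged host-side sketch (ctx, stream, and the device pointers input_ids/unique_ids are assumed to exist; names are illustrative):

```cpp
using dgl::runtime::cuda::OrderedHashTable;
// Phase 1: build the table from possibly-duplicated device-side IDs.
OrderedHashTable<int64_t> table(num_input, ctx, stream);
int64_t num_unique = 0;
table.FillWithDuplicates(input_ids, num_input, unique_ids, &num_unique, stream);
// Phase 2: hand the lightweight device view to kernels by value.
auto device_view = table.DeviceHandle();
```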
} // namespace cuda } // namespace cuda
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
......
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
* limitations under the License. * limitations under the License.
* *
* \file nccl_api.h * \file nccl_api.h
* \brief Wrapper around NCCL routines. * \brief Wrapper around NCCL routines.
*/ */
#ifndef DGL_RUNTIME_CUDA_NCCL_API_H_ #ifndef DGL_RUNTIME_CUDA_NCCL_API_H_
#define DGL_RUNTIME_CUDA_NCCL_API_H_ #define DGL_RUNTIME_CUDA_NCCL_API_H_
...@@ -27,11 +26,14 @@ ...@@ -27,11 +26,14 @@
// if not compiling with NCCL, this class will only support communicators of // if not compiling with NCCL, this class will only support communicators of
// size 1. // size 1.
#define NCCL_UNIQUE_ID_BYTES 128 #define NCCL_UNIQUE_ID_BYTES 128
typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; typedef struct {
char internal[NCCL_UNIQUE_ID_BYTES];
} ncclUniqueId;
typedef int ncclComm_t; typedef int ncclComm_t;
#endif #endif
#include <dgl/runtime/object.h> #include <dgl/runtime/object.h>
#include <string> #include <string>
namespace dgl { namespace dgl {
...@@ -59,17 +61,13 @@ DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId); ...@@ -59,17 +61,13 @@ DGL_DEFINE_OBJECT_REF(NCCLUniqueIdRef, NCCLUniqueId);
class NCCLCommunicator : public runtime::Object { class NCCLCommunicator : public runtime::Object {
public: public:
NCCLCommunicator( NCCLCommunicator(int size, int rank, ncclUniqueId id);
int size,
int rank,
ncclUniqueId id);
~NCCLCommunicator(); ~NCCLCommunicator();
// disable copying // disable copying
NCCLCommunicator(const NCCLCommunicator& other) = delete; NCCLCommunicator(const NCCLCommunicator& other) = delete;
NCCLCommunicator& operator=( NCCLCommunicator& operator=(const NCCLCommunicator& other);
const NCCLCommunicator& other);
ncclComm_t Get(); ncclComm_t Get();
...@@ -81,12 +79,9 @@ class NCCLCommunicator : public runtime::Object { ...@@ -81,12 +79,9 @@ class NCCLCommunicator : public runtime::Object {
* @param count The size of data to send to each rank. * @param count The size of data to send to each rank.
* @param stream The stream to operate on. * @param stream The stream to operate on.
*/ */
template<typename IdType> template <typename IdType>
void AllToAll( void AllToAll(
const IdType * send, const IdType* send, IdType* recv, int64_t count, cudaStream_t stream);
IdType * recv,
int64_t count,
cudaStream_t stream);
/** /**
* @brief Perform an all-to-all variable sized communication. * @brief Perform an all-to-all variable sized communication.
...@@ -99,13 +94,10 @@ class NCCLCommunicator : public runtime::Object { ...@@ -99,13 +94,10 @@ class NCCLCommunicator : public runtime::Object {
* @param type The type of data to send. * @param type The type of data to send.
* @param stream The stream to operate on. * @param stream The stream to operate on.
*/ */
template<typename DType> template <typename DType>
void AllToAllV( void AllToAllV(
const DType * const send, const DType* const send, const int64_t* send_prefix, DType* const recv,
const int64_t * send_prefix, const int64_t* recv_prefix, cudaStream_t stream);
DType * const recv,
const int64_t * recv_prefix,
cudaStream_t stream);
/** /**
* @brief Perform an all-to-all with sparse data (idx and value pairs). By * @brief Perform an all-to-all with sparse data (idx and value pairs). By
...@@ -124,16 +116,11 @@ class NCCLCommunicator : public runtime::Object { ...@@ -124,16 +116,11 @@ class NCCLCommunicator : public runtime::Object {
* receive on the host. * receive on the host.
* @param stream The stream to communicate on. * @param stream The stream to communicate on.
*/ */
template<typename IdType, typename DType> template <typename IdType, typename DType>
void SparseAllToAll( void SparseAllToAll(
const IdType * send_idx, const IdType* send_idx, const DType* send_value, const int64_t num_feat,
const DType * send_value, const int64_t* send_prefix, IdType* recv_idx, DType* recv_value,
const int64_t num_feat, const int64_t* recv_prefix, cudaStream_t stream);
const int64_t * send_prefix,
IdType * recv_idx,
DType * recv_value,
const int64_t * recv_prefix,
cudaStream_t stream);
int size() const; int size() const;
......
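AllToAll is only declared in this header; one common way to realize it is NCCL's grouped point-to-point API (available since NCCL 2.7). The sketch below illustrates that technique and is not necessarily DGL's actual implementation:

```cpp
// Hedged sketch: all-to-all as paired send/recv inside one NCCL group.
// Each of `world` ranks exchanges `count` elements with every peer.
template <typename IdType>
void AllToAllSketch(const IdType* send, IdType* recv, int64_t count,
                    int world, ncclComm_t comm, cudaStream_t stream) {
  ncclGroupStart();
  for (int r = 0; r < world; ++r) {
    ncclSend(send + r * count, count * sizeof(IdType), ncclChar, r, comm, stream);
    ncclRecv(recv + r * count, count * sizeof(IdType), ncclChar, r, comm, stream);
  }
  ncclGroupEnd();
}
```

Sending count * sizeof(IdType) bytes as ncclChar sidesteps mapping arbitrary IdTypes onto NCCL datatypes.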
...@@ -3,12 +3,12 @@ ...@@ -3,12 +3,12 @@
* \file src/runtime/dlpack_convert.cc * \file src/runtime/dlpack_convert.cc
* \brief Conversion between NDArray and DLPack. * \brief Conversion between NDArray and DLPack.
*/ */
#include <dgl/runtime/dlpack_convert.h>
#include <dlpack/dlpack.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/dlpack_convert.h>
#include <dgl/runtime/ndarray.h>
#include <dlpack/dlpack.h>
#include "runtime_base.h" #include "runtime_base.h"
// deleter for arrays used by DLPack exporter // deleter for arrays used by DLPack exporter
...@@ -69,8 +69,7 @@ NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) { ...@@ -69,8 +69,7 @@ NDArray DLPackConvert::FromDLPack(DLManagedTensor* tensor) {
void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) { void DLPackConvert::DLPackDeleter(NDArray::Container* ptr) {
// if the array is pinned by dgl, unpin it before freeing // if the array is pinned by dgl, unpin it before freeing
if (ptr->pinned_by_dgl_) if (ptr->pinned_by_dgl_) NDArray::UnpinContainer(ptr);
NDArray::UnpinContainer(ptr);
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx); DLManagedTensor* tensor = static_cast<DLManagedTensor*>(ptr->manager_ctx);
if (tensor->deleter != nullptr) { if (tensor->deleter != nullptr) {
(*tensor->deleter)(tensor); (*tensor->deleter)(tensor);
...@@ -95,7 +94,7 @@ DLManagedTensor* ContainerToDLPack(NDArray::Container* from) { ...@@ -95,7 +94,7 @@ DLManagedTensor* ContainerToDLPack(NDArray::Container* from) {
return ret; return ret;
} }
DLManagedTensor* DLPackConvert::ToDLPack(const NDArray &from) { DLManagedTensor* DLPackConvert::ToDLPack(const NDArray& from) {
return ContainerToDLPack(from.data_); return ContainerToDLPack(from.data_);
} }
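ToDLPack hands out a DLManagedTensor whose deleter releases the reference taken on the container; the DLPack contract is that the consumer calls that deleter exactly once when finished. A minimal consumer-side sketch (`array` is an assumed existing NDArray):

```cpp
DLManagedTensor* t = dgl::runtime::DLPackConvert::ToDLPack(array);
// ... pass t->dl_tensor to another DLPack-aware framework ...
if (t->deleter != nullptr) {
  (*t->deleter)(t);  // releases the reference ToDLPack() acquired
}
```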
...@@ -113,15 +112,14 @@ inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept { ...@@ -113,15 +112,14 @@ inline bool IsAligned(const void* ptr, std::uintptr_t alignment) noexcept {
return !(iptr % alignment); return !(iptr % alignment);
} }
int DGLArrayFromDLPack(DLManagedTensor* from, int DGLArrayFromDLPack(DLManagedTensor* from, DGLArrayHandle* out) {
DGLArrayHandle* out) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from)); *out = NDArray::Internal::MoveAsDGLArray(DLPackConvert::FromDLPack(from));
API_END(); API_END();
} }
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out, int DGLArrayToDLPack(
int alignment) { DGLArrayHandle from, DLManagedTensor** out, int alignment) {
API_BEGIN(); API_BEGIN();
auto* nd_container = reinterpret_cast<NDArray::Container*>(from); auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DGLArray* nd = &(nd_container->dl_tensor); DGLArray* nd = &(nd_container->dl_tensor);
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* \brief Module to load from dynamic shared library. * \brief Module to load from dynamic shared library.
*/ */
#include <dgl/runtime/module.h> #include <dgl/runtime/module.h>
#include <dgl/runtime/registry.h>
#include <dgl/runtime/packed_func.h> #include <dgl/runtime/packed_func.h>
#include <dgl/runtime/registry.h>
#include "module_util.h" #include "module_util.h"
#if defined(_WIN32) #if defined(_WIN32)
...@@ -25,9 +26,7 @@ class DSOModuleNode final : public ModuleNode { ...@@ -25,9 +26,7 @@ class DSOModuleNode final : public ModuleNode {
if (lib_handle_) Unload(); if (lib_handle_) Unload();
} }
const char* type_key() const final { const char* type_key() const final { return "dso"; }
return "dso";
}
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
...@@ -36,8 +35,9 @@ class DSOModuleNode final : public ModuleNode { ...@@ -36,8 +35,9 @@ class DSOModuleNode final : public ModuleNode {
if (name == runtime::symbol::dgl_module_main) { if (name == runtime::symbol::dgl_module_main) {
const char* entry_name = reinterpret_cast<const char*>( const char* entry_name = reinterpret_cast<const char*>(
GetSymbol(runtime::symbol::dgl_module_main)); GetSymbol(runtime::symbol::dgl_module_main));
CHECK(entry_name!= nullptr) CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::dgl_module_main << " is not presented"; << "Symbol " << runtime::symbol::dgl_module_main
<< " is not presented";
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name)); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(entry_name));
} else { } else {
faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str())); faddr = reinterpret_cast<BackendPackedCFunc>(GetSymbol(name.c_str()));
...@@ -48,17 +48,15 @@ class DSOModuleNode final : public ModuleNode { ...@@ -48,17 +48,15 @@ class DSOModuleNode final : public ModuleNode {
void Init(const std::string& name) { void Init(const std::string& name) {
Load(name); Load(name);
if (auto *ctx_addr = if (auto* ctx_addr = reinterpret_cast<void**>(
reinterpret_cast<void**>(GetSymbol(runtime::symbol::dgl_module_ctx))) { GetSymbol(runtime::symbol::dgl_module_ctx))) {
*ctx_addr = this; *ctx_addr = this;
} }
InitContextFunctions([this](const char* fname) { InitContextFunctions(
return GetSymbol(fname); [this](const char* fname) { return GetSymbol(fname); });
});
// Load the imported modules // Load the imported modules
const char* dev_mblob = const char* dev_mblob = reinterpret_cast<const char*>(
reinterpret_cast<const char*>( GetSymbol(runtime::symbol::dgl_dev_mblob));
GetSymbol(runtime::symbol::dgl_dev_mblob));
if (dev_mblob != nullptr) { if (dev_mblob != nullptr) {
ImportModuleBlob(dev_mblob, &imports_); ImportModuleBlob(dev_mblob, &imports_);
} }
...@@ -79,11 +77,9 @@ class DSOModuleNode final : public ModuleNode { ...@@ -79,11 +77,9 @@ class DSOModuleNode final : public ModuleNode {
} }
void* GetSymbol(const char* name) { void* GetSymbol(const char* name) {
return reinterpret_cast<void*>( return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*) GetProcAddress(lib_handle_, (LPCSTR)name)); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
} }
void Unload() { FreeLibrary(lib_handle_); }
#else #else
// Library handle // Library handle
void* lib_handle_{nullptr}; void* lib_handle_{nullptr};
...@@ -91,23 +87,18 @@ class DSOModuleNode final : public ModuleNode { ...@@ -91,23 +87,18 @@ class DSOModuleNode final : public ModuleNode {
void Load(const std::string& name) { void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
CHECK(lib_handle_ != nullptr) CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name << "Failed to load dynamic shared library " << name << " " << dlerror();
<< " " << dlerror();
}
void* GetSymbol(const char* name) {
return dlsym(lib_handle_, name);
}
void Unload() {
dlclose(lib_handle_);
} }
void* GetSymbol(const char* name) { return dlsym(lib_handle_, name); }
void Unload() { dlclose(lib_handle_); }
#endif #endif
}; };
DGL_REGISTER_GLOBAL("module.loadfile_so") DGL_REGISTER_GLOBAL("module.loadfile_so")
.set_body([](DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>(); std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]); n->Init(args[0]);
*rv = runtime::Module(n); *rv = runtime::Module(n);
}); });
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
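For completeness, a hedged sketch of how the global registered above is typically reached from C++ (the library path and function name are illustrative):

```cpp
const dgl::runtime::PackedFunc* load =
    dgl::runtime::Registry::Get("module.loadfile_so");
dgl::runtime::Module mod = (*load)("libmyops.so");   // path is illustrative
dgl::runtime::PackedFunc f = mod.GetFunction("my_packed_func");
```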