Unverified Commit 3e696922 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[RPC] New RPC infrastructure. (#1549)



* WIP: rpc components

* client & server

* move network package to rpc

* fix include

* fix compile

* c api

* wip: test

* add basic tests

* missing file

* [RPC] Zero copy serializer (#1517)

* zerocopy serialization

* add test for HeteroGraph

* fix lint

* remove unnecessary codes

* add comment

* lint

* lint

* disable pylint for now

* add include for win

* windows guard

* lint

* lint

* skip test on windows

* refactor

* add comment

* fix

* comment

* 1111

* fix

* Update Jenkinsfile

* [RPC] Implementation of RPC infra (#1544)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* remove client.cc and server.cc

* fix lint

* update

* update

* fix lint

* update

* fix lint

* update

* update

* update

* update

* update

* update

* update test

* update

* update test

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update comment

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7ba9cbc6
...@@ -109,6 +109,7 @@ file(GLOB_RECURSE DGL_SRC_1 ...@@ -109,6 +109,7 @@ file(GLOB_RECURSE DGL_SRC_1
src/api/*.cc src/api/*.cc
src/graph/*.cc src/graph/*.cc
src/scheduler/*.cc src/scheduler/*.cc
src/rpc/*.cc
) )
list(APPEND DGL_SRC ${DGL_SRC_1}) list(APPEND DGL_SRC ${DGL_SRC_1})
......
...@@ -6,11 +6,14 @@ ...@@ -6,11 +6,14 @@
#ifndef DGL_RUNTIME_NDARRAY_H_ #ifndef DGL_RUNTIME_NDARRAY_H_
#define DGL_RUNTIME_NDARRAY_H_ #define DGL_RUNTIME_NDARRAY_H_
#include <string>
#include <atomic> #include <atomic>
#include <vector> #include <string>
#include <utility> #include <utility>
#include <vector>
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "dlpack/dlpack.h"
#include "serializer.h" #include "serializer.h"
#include "shared_mem.h" #include "shared_mem.h"
...@@ -158,12 +161,12 @@ class NDArray { ...@@ -158,12 +161,12 @@ class NDArray {
* \param stream The input data stream * \param stream The input data stream
* \return Whether load is successful * \return Whether load is successful
*/ */
inline bool Load(dmlc::Stream* stream); bool Load(dmlc::Stream* stream);
/*! /*!
* \brief Save NDArray to stream * \brief Save NDArray to stream
* \param stream The output data stream * \param stream The output data stream
*/ */
inline void Save(dmlc::Stream* stream) const; void Save(dmlc::Stream* stream) const;
/*! /*!
* \brief Create a NDArray that shares the data memory with the current one. * \brief Create a NDArray that shares the data memory with the current one.
* \param shape The shape of the new array. * \param shape The shape of the new array.
...@@ -237,6 +240,10 @@ class NDArray { ...@@ -237,6 +240,10 @@ class NDArray {
template<typename T> template<typename T>
std::vector<T> ToVector() const; std::vector<T> ToVector() const;
#ifndef _WIN32
std::shared_ptr<SharedMemory> GetSharedMem() const;
#endif // _WIN32
/*! /*!
* \brief Function to copy data from one array to another. * \brief Function to copy data from one array to another.
* \param from The source array. * \param from The source array.
...@@ -264,6 +271,7 @@ class NDArray { ...@@ -264,6 +271,7 @@ class NDArray {
*/ */
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
/*! /*!
* \brief Reference counted Container object used to back NDArray. * \brief Reference counted Container object used to back NDArray.
* *
...@@ -456,57 +464,6 @@ inline bool SaveDLTensor(dmlc::Stream* strm, ...@@ -456,57 +464,6 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
return true; return true;
} }
inline void NDArray::Save(dmlc::Stream* strm) const {
SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
}
inline bool NDArray::Load(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kDGLNDArrayMagic)
<< "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format";
CHECK_EQ(ctx.device_type, kDLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
if (data_byte_size != 0) {
// strm->Read will return the total number of elements successfully read.
// Therefore if data_byte_size is zero, the CHECK below would fail.
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
}
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
*this = ret;
return true;
}
} // namespace runtime } // namespace runtime
} // namespace dgl } // namespace dgl
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
#include <dmlc/serializer.h> #include <dmlc/serializer.h>
#include "c_runtime_api.h" #include "c_runtime_api.h"
#include "smart_ptr_serializer.h" #include "smart_ptr_serializer.h"
#include "ndarray.h"
namespace dmlc { namespace dmlc {
namespace serializer { namespace serializer {
......
...@@ -27,6 +27,14 @@ class SharedMemory { ...@@ -27,6 +27,14 @@ class SharedMemory {
* and will be responsible for deleting it when the object is destroyed. * and will be responsible for deleting it when the object is destroyed.
*/ */
bool own; bool own;
/* \brief the file descripter of the shared memory. */
int fd;
/* \brief the address of the shared memory. */
void *ptr;
/* \brief the size of the shared memory. */
size_t size;
/* /*
* \brief the name of the object. * \brief the name of the object.
* *
...@@ -34,14 +42,12 @@ class SharedMemory { ...@@ -34,14 +42,12 @@ class SharedMemory {
* the file name that identifies the shared memory. * the file name that identifies the shared memory.
*/ */
std::string name; std::string name;
/* \brief the file descripter of the shared memory. */
int fd;
/* \brief the address of the shared memory. */
void *ptr;
/* \brief the size of the shared memory. */
size_t size;
public: public:
/* \brief Get the filename of shared memory file
*/
std::string GetName() const { return name; }
/* /*
* \brief constructor of the shared memory. * \brief constructor of the shared memory.
* \param name The file corresponding to the shared memory. * \param name The file corresponding to the shared memory.
......
/*!
* Copyright (c) 2020 by Contributors
* \file rpc/shared_mem_serializer.h
* \brief headers for serializer.
*/
#ifndef DGL_ZEROCOPY_SERIALIZER_H_
#define DGL_ZEROCOPY_SERIALIZER_H_
#include <dgl/runtime/ndarray.h>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <dmlc/serializer.h>
#include <deque>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "dmlc/logging.h"
namespace dgl {
// StringStreamWithBuffer is a stream backed by a std::string. It supports
// serializing and deserializing NDArrays without copying their contents.
// If the stream is created for sending/receiving data through the network,
// the data pointer of each NDArray is transmitted directly without any copy.
// Otherwise the stream is for sending/receiving data to another process on
// the same machine, so if an NDArray is stored in shared memory, only the
// shared-memory name is recorded instead of the actual data buffer.
class StringStreamWithBuffer : public dmlc::MemoryStringStream {
 public:
  // Buffer type. Stores the NDArray to maintain reference counting and thus
  // keep the `data` pointer alive while the buffer is in flight.
  struct Buffer {
    dgl::runtime::NDArray tensor = dgl::runtime::NDArray();
    void* data = nullptr;
    int64_t size = 0;

    // Buffer whose memory is kept alive by `tensor` (local serialization).
    Buffer(const dgl::runtime::NDArray& tensor, void* data, int64_t data_size)
        : tensor(tensor), data(data), size(data_size) {}

    // Non-owning buffer wrapping a raw pointer received from remote.
    explicit Buffer(void* data) : data(data) {}
  };

  /*!
   * \brief This constructor is for the writing scenario, or for reading data
   * produced on the local machine.
   * \param metadata_ptr The string to write to / load from.
   * \param send_to_remote Whether this stream will be deserialized on a
   * remote machine (true) or on the local machine (false). If true, NDArray
   * data pointers are recorded in the buffer list instead of being copied.
   *
   * For example:
   * std::string blob;
   * // Write to send to local
   * StringStreamWithBuffer buf_strm(&blob, false)
   * // Write to send to remote
   * StringStreamWithBuffer buf_strm(&blob, true)
   * // Or
   * StringStreamWithBuffer buf_strm(&blob)
   * // Read from local
   * StringStreamWithBuffer buf_strm(&blob, false)
   */
  explicit StringStreamWithBuffer(std::string* metadata_ptr,
                                  bool send_to_remote = true)
      : MemoryStringStream(metadata_ptr),
        buffer_list_(),
        send_to_remote_(send_to_remote) {}

  /*!
   * \brief This constructor is for reading data received from a remote
   * machine.
   * \param metadata_ptr The string to load from.
   * \param data_ptr_list List of data pointers used to reconstruct NDArrays.
   *
   * For example:
   * std::string blob;
   * std::vector<void*> data_ptr_list;
   * // Read from remotely sent pointer list
   * StringStreamWithBuffer buf_strm(&blob, data_ptr_list)
   */
  StringStreamWithBuffer(std::string* metadata_ptr,
                         const std::vector<void*>& data_ptr_list)
      : MemoryStringStream(metadata_ptr), send_to_remote_(true) {
    for (void* data : data_ptr_list) {
      buffer_list_.emplace_back(data);
    }
  }

  /*!
   * \brief Push an NDArray into the stream.
   * If send_to_remote=true, the NDArray is appended to the buffer list.
   * If send_to_remote=false, the NDArray is saved into the backing string
   * (e.g. via its shared-memory name).
   */
  void PushNDArray(const runtime::NDArray& tensor);

  /*!
   * \brief Pop an NDArray from the stream.
   * If send_to_remote=true, the NDArray is reconstructed from the buffer
   * list. If send_to_remote=false, it is reconstructed from shared memory.
   */
  dgl::runtime::NDArray PopNDArray();

  /*!
   * \brief Whether this stream is intended for remote usage.
   */
  bool send_to_remote() { return send_to_remote_; }

  /*!
   * \brief Get the underlying buffer list.
   */
  const std::deque<Buffer>& buffer_list() const { return buffer_list_; }

 private:
  // NDArray payloads recorded for zero-copy transmission.
  std::deque<Buffer> buffer_list_;
  // True if the stream targets a remote machine.
  bool send_to_remote_;
};  // class StringStreamWithBuffer
} // namespace dgl
#endif // DGL_ZEROCOPY_SERIALIZER_H_
...@@ -10,6 +10,7 @@ from .backend import load_backend, backend_name ...@@ -10,6 +10,7 @@ from .backend import load_backend, backend_name
from . import function from . import function
from . import contrib from . import contrib
from . import container from . import container
from . import distributed
from . import random from . import random
from . import sampling from . import sampling
......
...@@ -49,6 +49,7 @@ class ObjectBase(object): ...@@ -49,6 +49,7 @@ class ObjectBase(object):
def __getattr__(self, name): def __getattr__(self, name):
if name == 'handle': if name == 'handle':
raise AttributeError("'handle' is a reserved attribute name that should not be used") raise AttributeError("'handle' is a reserved attribute name that should not be used")
print('in get attr:', name)
ret_val = DGLValue() ret_val = DGLValue()
ret_type_code = ctypes.c_int() ret_type_code = ctypes.c_int()
ret_success = ctypes.c_int() ret_success = ctypes.c_int()
......
...@@ -3,3 +3,7 @@ ...@@ -3,3 +3,7 @@
from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split
from .partition import partition_graph, load_partition from .partition import partition_graph, load_partition
from .graph_partition_book import GraphPartitionBook from .graph_partition_book import GraphPartitionBook
from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, finalize_client, shutdown_servers
"""Define all the constants used by DGL rpc"""
# Maximum size of message queue in bytes
MAX_QUEUE_SIZE = 20*1024*1024*1024
"""RPC components. They are typically functions or utilities used by both
server and clients."""
import abc
import pickle
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
from ..base import DGLError
from .. import backend as F
# Public names re-exported via `from .rpc import *`.
__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
    'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
    'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
    'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
    'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call']

# Registries populated by register_service():
# Request subclass -> service ID
REQUEST_CLASS_TO_SERVICE_ID = {}
# Response subclass -> service ID
RESPONSE_CLASS_TO_SERVICE_ID = {}
# service ID -> (request class, response class or None)
SERVICE_ID_TO_PROPERTY = {}
def read_ip_config(filename):
    """Read network configuration information of server from file.

    The format of configuration file should be:

        [ip] [base_port] [server_count]

        172.31.40.143 30050 2
        172.31.36.140 30050 2
        172.31.47.147 30050 2
        172.31.30.180 30050 2

    Note that DGL supports multiple backup servers that share data with each
    other on the same machine via shared-memory tensors. server_count should
    be >= 1. For example, server_count of 5 means 1 main server and 4 backup
    servers on that machine; the count can differ per machine.

    Parameters
    ----------
    filename : str
        Path of IP configuration file.

    Returns
    -------
    dict
        server namebook.
        The key is server_id (int)
        The value is [machine_id, ip, port, group_count] ([int, str, int, int])
        e.g.,
          {0:[0, '172.31.40.143', 30050, 2],
           1:[0, '172.31.40.143', 30051, 2],
           2:[1, '172.31.36.140', 30050, 2],
           3:[1, '172.31.36.140', 30051, 2]}
    """
    assert len(filename) > 0, 'filename cannot be empty.'
    server_namebook = {}
    try:
        server_id = 0
        machine_id = 0
        # `with` guarantees the file handle is closed (the original code
        # leaked it), and split() tolerates repeated whitespace.
        with open(filename) as conf_file:
            lines = [line.strip() for line in conf_file if line.strip()]
        for line in lines:
            ip_addr, port, server_count = line.split()
            for s_count in range(int(server_count)):
                # Backup servers on the same machine get consecutive ports.
                server_namebook[server_id] = \
                    [int(machine_id), ip_addr, int(port)+s_count, int(server_count)]
                server_id += 1
            machine_id += 1
    except ValueError:
        # Keep best-effort behavior: report the formatting problem and
        # return whatever was parsed so far.
        print("Error: data format on each line should be: [ip] [base_port] [server_count]")
    return server_namebook
def create_sender(max_queue_size, net_type):
    """Create the RPC sender of this process.

    Parameters
    ----------
    max_queue_size : int
        Maximal size (bytes) of the network queue buffer.
    net_type : str
        Networking backend. Current options are: 'socket'.
    """
    queue_size = int(max_queue_size)
    _CAPI_DGLRPCCreateSender(queue_size, net_type)

def create_receiver(max_queue_size, net_type):
    """Create the RPC receiver of this process.

    Parameters
    ----------
    max_queue_size : int
        Maximal size (bytes) of the network queue buffer.
    net_type : str
        Networking backend. Current options are: 'socket'.
    """
    queue_size = int(max_queue_size)
    _CAPI_DGLRPCCreateReceiver(queue_size, net_type)

def finalize_sender():
    """Release the RPC sender owned by this process."""
    _CAPI_DGLRPCFinalizeSender()

def finalize_receiver():
    """Release the RPC receiver owned by this process."""
    _CAPI_DGLRPCFinalizeReceiver()
def receiver_wait(ip_addr, port, num_senders):
    """Block until every sender has connected to this receiver.

    Parameters
    ----------
    ip_addr : str
        Receiver's IP address, e.g., '192.168.8.12'.
    port : int
        Receiver's listening port.
    num_senders : int
        Total number of senders to wait for.
    """
    _CAPI_DGLRPCReceiverWait(ip_addr, int(port), int(num_senders))

def add_receiver_addr(ip_addr, port, recv_id):
    """Register a receiver's address in the sender's namebook.

    Parameters
    ----------
    ip_addr : str
        Receiver's IP address, e.g., '192.168.8.12'.
    port : int
        Receiver's listening port.
    recv_id : int
        Receiver's ID.
    """
    _CAPI_DGLRPCAddReceiver(ip_addr, int(port), int(recv_id))

def sender_connect():
    """Establish connections to all registered receivers."""
    _CAPI_DGLRPCSenderConnect()
def set_rank(rank):
    """Set the rank of this process.

    For a client process this equals the client ID; for a server process it
    equals the server ID.

    Parameters
    ----------
    rank : int
        Rank value.
    """
    _CAPI_DGLRPCSetRank(int(rank))

def get_rank():
    """Return the rank of this process.

    For a client process this equals the client ID; for a server process it
    equals the server ID.

    Returns
    -------
    int
        Rank value.
    """
    return _CAPI_DGLRPCGetRank()

def set_machine_id(machine_id):
    """Set the ID of the current machine.

    Parameters
    ----------
    machine_id : int
        Current machine ID.
    """
    _CAPI_DGLRPCSetMachineID(int(machine_id))

def get_machine_id():
    """Return the ID of the current machine.

    Returns
    -------
    int
        Machine ID.
    """
    return _CAPI_DGLRPCGetMachineID()
def set_num_machines(num_machines):
    """Set the number of machines in the cluster.

    Parameters
    ----------
    num_machines : int
        Number of machines.
    """
    _CAPI_DGLRPCSetNumMachines(int(num_machines))

def get_num_machines():
    """Return the number of machines in the cluster.

    Returns
    -------
    int
        Number of machines.
    """
    return _CAPI_DGLRPCGetNumMachines()

def set_num_server(num_server):
    """Set the total number of servers."""
    _CAPI_DGLRPCSetNumServer(int(num_server))

def get_num_server():
    """Return the total number of servers."""
    return _CAPI_DGLRPCGetNumServer()

def incr_msg_seq():
    """Increment the message sequence number and return the previous one.

    Returns
    -------
    long
        Message sequence number before the increment.
    """
    return _CAPI_DGLRPCIncrMsgSeq()

def get_msg_seq():
    """Return the current message sequence number.

    Returns
    -------
    long
        Message sequence number.
    """
    return _CAPI_DGLRPCGetMsgSeq()

def set_msg_seq(msg_seq):
    """Set the current message sequence number.

    Parameters
    ----------
    msg_seq : int
        Sequence number of the current rpc message.
    """
    _CAPI_DGLRPCSetMsgSeq(int(msg_seq))
def register_service(service_id, req_cls, res_cls=None):
    """Register a service with the RPC layer.

    Parameters
    ----------
    service_id : int
        Service ID.
    req_cls : class
        Request class.
    res_cls : class, optional
        Response class. If None, the service has no response.
    """
    # Record both directions: id -> classes and class -> id.
    SERVICE_ID_TO_PROPERTY[service_id] = (req_cls, res_cls)
    REQUEST_CLASS_TO_SERVICE_ID[req_cls] = service_id
    if res_cls is not None:
        RESPONSE_CLASS_TO_SERVICE_ID[res_cls] = service_id
def get_service_property(service_id):
    """Look up the request/response classes registered for a service.

    Parameters
    ----------
    service_id : int
        Service ID.

    Returns
    -------
    (class, class)
        (Request class, Response class or None)
    """
    prop = SERVICE_ID_TO_PROPERTY[service_id]
    return prop
class Request:
    """Base class for all RPC requests."""

    @abc.abstractmethod
    def __getstate__(self):
        """Return the serializable state of this request.

        Must be implemented by subclasses. Array members should be returned
        as individual values rather than wrapped in containers such as a
        dictionary or a list.
        """

    @abc.abstractmethod
    def __setstate__(self, state):
        """Rebuild the request object from its serialized state.

        Must be implemented by subclasses.
        """

    @abc.abstractmethod
    def process_request(self, server_state):
        """Handle this request on the server side.

        Must be implemented by subclasses.

        Parameters
        ----------
        server_state : ServerState
            Server state data.

        Returns
        -------
        Response
            Response of this request, or None if no response.
        """

    @property
    def service_id(self):
        """Service ID this request class was registered under."""
        cls = type(self)
        sid = REQUEST_CLASS_TO_SERVICE_ID.get(cls)
        if sid is None:
            raise DGLError('Request class {} has not been registered as a service.'.format(cls))
        return sid
class Response:
    """Base class for all RPC responses."""

    @abc.abstractmethod
    def __getstate__(self):
        """Return the serializable state of this response.

        Must be implemented by subclasses. Array members should be returned
        as individual values rather than wrapped in containers such as a
        dictionary or a list.
        """

    @abc.abstractmethod
    def __setstate__(self, state):
        """Rebuild the response object from its serialized state.

        Must be implemented by subclasses.
        """

    @property
    def service_id(self):
        """Service ID this response class was registered under."""
        cls = type(self)
        sid = RESPONSE_CLASS_TO_SERVICE_ID.get(cls)
        if sid is None:
            raise DGLError('Response class {} has not been registered as a service.'.format(cls))
        return sid
def serialize_to_payload(serializable):
    """Serialize an object into a metadata buffer plus tensor payloads.

    The object must have implemented the __getstate__ function.

    Parameters
    ----------
    serializable : object
        Any serializable object.

    Returns
    -------
    bytearray
        Serialized payload buffer.
    list[Tensor]
        A list of tensor payloads.
    """
    state = serializable.__getstate__()
    if not isinstance(state, tuple):
        state = (state,)
    # Split the state into tensors (sent zero-copy) and everything else
    # (pickled along with the positions it came from).
    array_state = []
    nonarray_pos = []
    nonarray_state = []
    for idx, item in enumerate(state):
        if not F.is_tensor(item):
            nonarray_pos.append(idx)
            nonarray_state.append(item)
        else:
            array_state.append(item)
    data = bytearray(pickle.dumps((nonarray_pos, nonarray_state)))
    return data, array_state
def deserialize_from_payload(cls, data, tensors):
    """Deserialize and reconstruct the object from payload.

    The object must have implemented the __setstate__ function.

    Parameters
    ----------
    cls : class
        The object class.
    data : bytearray
        Serialized data buffer.
    tensors : list[Tensor]
        A list of tensor payloads.

    Returns
    -------
    object
        De-serialized object of class cls.
    """
    pos, nonarray_state = pickle.loads(data)
    total_len = len(nonarray_state) + len(tensors)
    state = [None] * total_len
    for i, no_state in zip(pos, nonarray_state):
        state[i] = no_state
    # Tensors occupy exactly the positions NOT recorded as non-array state.
    # Using the recorded positions (instead of scanning for None slots, as
    # before) keeps a legitimate None value in the state from being
    # overwritten by a tensor payload.
    nonarray_positions = set(pos)
    j = 0
    for i in range(total_len):
        if i not in nonarray_positions:
            state[i] = tensors[j]
            j += 1
    if len(state) == 1:
        state = state[0]
    else:
        state = tuple(state)
    # Bypass __init__; __setstate__ restores the object's attributes.
    obj = cls.__new__(cls)
    obj.__setstate__(state)
    return obj
@register_object('rpc.RPCMessage')
class RPCMessage(ObjectBase):
    """Serialized RPC message that can be sent to remote processes.

    This class can be used as argument or return value for C API.

    Attributes
    ----------
    service_id : int
        The remote service ID the message wishes to invoke.
    msg_seq : int
        Sequence number of this message.
    client_id : int
        The client ID.
    server_id : int
        The server ID.
    data : bytearray
        Payload buffer carried by this request.
    tensors : list[tensor]
        Extra payloads in the form of tensors.
    """
    def __init__(self, service_id, msg_seq, client_id, server_id, data, tensors):
        # Construct the underlying C++ RPCMessage object; tensors are
        # converted to DGL NDArrays without copying (zero-copy).
        self.__init_handle_by_constructor__(
            _CAPI_DGLRPCCreateRPCMessage,
            int(service_id),
            int(msg_seq),
            int(client_id),
            int(server_id),
            data,
            [F.zerocopy_to_dgl_ndarray(tsor) for tsor in tensors])

    @property
    def service_id(self):
        """Get service ID."""
        return _CAPI_DGLRPCMessageGetServiceId(self)

    @property
    def msg_seq(self):
        """Get message sequence number."""
        return _CAPI_DGLRPCMessageGetMsgSeq(self)

    @property
    def client_id(self):
        """Get client ID."""
        return _CAPI_DGLRPCMessageGetClientId(self)

    @property
    def server_id(self):
        """Get server ID."""
        return _CAPI_DGLRPCMessageGetServerId(self)

    @property
    def data(self):
        """Get payload buffer."""
        return _CAPI_DGLRPCMessageGetData(self)

    @property
    def tensors(self):
        """Get tensor payloads."""
        rst = _CAPI_DGLRPCMessageGetTensors(self)
        # NOTE(review): `.data` presumably unwraps the NDArray from the
        # returned Value wrapper -- confirm against the C API.
        return [F.zerocopy_from_dgl_ndarray(tsor.data) for tsor in rst]
def send_request(target, request):
    """Send one request to the target server.

    Serializes the given request object into an :class:`RPCMessage` and
    sends it out.

    The operation is non-blocking -- it does not guarantee the payloads have
    reached the target or even have left the sender process. However,
    all the payloads (i.e., data and arrays) can be safely freed after this
    function returns.

    Parameters
    ----------
    target : int
        ID of target server.
    request : Request
        The request to send.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    service = request.service_id
    seq = incr_msg_seq()
    my_id = get_rank()
    data, tensors = serialize_to_payload(request)
    send_rpc_message(RPCMessage(service, seq, my_id, target, data, tensors))
def send_response(target, response):
    """Send one response to the target client.

    Serializes the given response object into an :class:`RPCMessage` and
    sends it out.

    The operation is non-blocking -- it does not guarantee the payloads have
    reached the target or even have left the sender process. However,
    all the payloads (i.e., data and arrays) can be safely freed after this
    function returns.

    Parameters
    ----------
    target : int
        ID of target client.
    response : Response
        The response to send.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    service = response.service_id
    seq = get_msg_seq()
    my_id = get_rank()
    data, tensors = serialize_to_payload(response)
    send_rpc_message(RPCMessage(service, seq, target, my_id, data, tensors))
def recv_request(timeout=0):
    """Receive one request.

    Receives one :class:`RPCMessage` and de-serializes it into a proper
    Request object.

    The operation is blocking -- it returns when it receives any message
    or it times out.

    Parameters
    ----------
    timeout : int, optional
        The timeout value in milliseconds. If zero, wait indefinitely.

    Returns
    -------
    req : Request
        One request received from the target, or None if it times out.
    client_id : int
        Client's ID received from the target.

    Raises
    ------
    DGLError
        If the message was routed to the wrong server, or no request class
        is registered for its service ID.
    ConnectionError if there is any problem with the connection.
    """
    # TODO(chao): handle timeout
    msg = recv_rpc_message(timeout)
    if msg is None:
        return None
    set_msg_seq(msg.msg_seq)
    # Validate routing *before* deserializing: a misrouted message would
    # otherwise be decoded with the wrong class, wasting work and producing
    # a confusing error.
    if msg.server_id != get_rank():
        raise DGLError('Got request sent to server {}, '
                       'different from my rank {}!'.format(msg.server_id, get_rank()))
    req_cls, _ = SERVICE_ID_TO_PROPERTY[msg.service_id]
    if req_cls is None:
        raise DGLError('Got request message from service ID {}, '
                       'but no request class is registered.'.format(msg.service_id))
    req = deserialize_from_payload(req_cls, msg.data, msg.tensors)
    return req, msg.client_id
def recv_response(timeout=0):
    """Receive one response.

    Receives one :class:`RPCMessage` and de-serializes it into a proper
    Response object.

    The operation is blocking -- it returns when it receives any message
    or it times out.

    Parameters
    ----------
    timeout : int, optional
        The timeout value in milliseconds. If zero, wait indefinitely.

    Returns
    -------
    res : Response
        One response received from the target, or None if it times out.

    Raises
    ------
    DGLError
        If the message was routed to the wrong client, or no response class
        is registered for its service ID.
    ConnectionError if there is any problem with the connection.
    """
    # TODO(chao): handle timeout
    msg = recv_rpc_message(timeout)
    if msg is None:
        return None
    # Validate routing *before* deserializing (see recv_request). Also
    # fixes the 'reponse' typo of the original error message.
    if msg.client_id != get_rank():
        raise DGLError('Got response of request sent by client {}, '
                       'different from my rank {}!'.format(msg.client_id, get_rank()))
    _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
    if res_cls is None:
        raise DGLError('Got response message from service ID {}, '
                       'but no response class is registered.'.format(msg.service_id))
    res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
    return res
def remote_call(target_and_requests, timeout=0):
    """Invoke registered services on remote servers and collect responses.

    The operation is blocking -- it returns when it receives all responses
    or it times out.

    Parameters
    ----------
    target_and_requests : list[(int, Request)]
        A list of requests and the server they should be sent to.
    timeout : int, optional
        The timeout value in milliseconds. If zero, wait indefinitely.

    Returns
    -------
    list[Response]
        Responses for each target-request pair. If the request does not have
        a response class registered, None is placed at its position.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    # TODO(chao): handle timeout
    all_res = [None] * len(target_and_requests)
    # Responses can arrive out of order; map each message sequence number
    # back to the position of its request in the input list.
    msgseq2pos = {}
    num_res = 0
    myrank = get_rank()
    for pos, (target, request) in enumerate(target_and_requests):
        # send request
        service_id = request.service_id
        msg_seq = incr_msg_seq()
        client_id = get_rank()
        server_id = target
        data, tensors = serialize_to_payload(request)
        msg = RPCMessage(service_id, msg_seq, client_id, server_id, data, tensors)
        send_rpc_message(msg)
        # check if has response
        res_cls = get_service_property(service_id)[1]
        if res_cls is not None:
            num_res += 1
            msgseq2pos[msg_seq] = pos
    while num_res != 0:
        # recv response
        # NOTE(review): recv_rpc_message is documented to return None on
        # timeout, but that case is not handled here -- confirm.
        msg = recv_rpc_message(timeout)
        num_res -= 1
        _, res_cls = SERVICE_ID_TO_PROPERTY[msg.service_id]
        if res_cls is None:
            raise DGLError('Got response message from service ID {}, '
                           'but no response class is registered.'.format(msg.service_id))
        res = deserialize_from_payload(res_cls, msg.data, msg.tensors)
        if msg.client_id != myrank:
            raise DGLError('Got reponse of request sent by client {}, '
                           'different from my rank {}!'.format(msg.client_id, myrank))
        # set response
        all_res[msgseq2pos[msg.msg_seq]] = res
    return all_res
def send_rpc_message(msg):
    """Send one RPC message to a remote process.

    Non-blocking: returning does not guarantee the payload has reached the
    target or even left this process, but all payloads (data and tensors)
    may be safely freed once this call returns.

    The data buffer of the message is copied into an internal buffer for
    transmission, while tensor payloads are sent without any memory copy
    (zero-copy); the sending threads keep references to the tensors until
    their contents have been transmitted.

    Parameters
    ----------
    msg : RPCMessage
        The message to send.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    _CAPI_DGLRPCSendRPCMessage(msg)
def recv_rpc_message(timeout=0):
    """Receive one RPC message.

    Blocking: returns when any message arrives or the timeout elapses.

    Parameters
    ----------
    timeout : int, optional
        The timeout value in milliseconds. If zero, wait indefinitely.

    Returns
    -------
    msg : RPCMessage
        One rpc message received from the target, or None if it times out.

    Raises
    ------
    ConnectionError if there is any problem with the connection.
    """
    empty_msg = _CAPI_DGLRPCCreateEmptyRPCMessage()
    _CAPI_DGLRPCRecvRPCMessage(timeout, empty_msg)
    return empty_msg
def finalize_server():
    """Release the sender/receiver resources held by the current server."""
    finalize_sender()
    finalize_receiver()
    rank = get_rank()
    print("Server (%d) shutdown." % rank)
############### Some basic services will be defined here #############

CLIENT_REGISTER = 22451

class ClientRegisterRequest(Request):
    """Request used by a client to report its IP address to the server.

    Parameters
    ----------
    ip_addr : str
        Client's IP address.
    """
    def __init__(self, ip_addr):
        self.ip_addr = ip_addr

    def __getstate__(self):
        return self.ip_addr

    def __setstate__(self, state):
        self.ip_addr = state

    def process_request(self, server_state):
        # Registration bookkeeping happens elsewhere; nothing to do here.
        return None
class ClientRegisterResponse(Response):
    """Response carrying the ID the server assigned to a client.

    Parameters
    ----------
    client_id : int
        Client's ID.
    """
    def __init__(self, client_id):
        self.client_id = client_id

    def __getstate__(self):
        return self.client_id

    def __setstate__(self, state):
        self.client_id = state
SHUT_DOWN_SERVER = 22452

class ShutDownRequest(Request):
    """Client sends this request to shut down a server.

    This request has no response.

    Parameters
    ----------
    client_id : int
        Client's ID. Only client 0 may shut down servers.
    """
    def __init__(self, client_id):
        self.client_id = client_id

    def __getstate__(self):
        return self.client_id

    def __setstate__(self, state):
        self.client_id = state

    def process_request(self, server_state):
        # Use an explicit check instead of `assert`: asserts are stripped
        # under `python -O`, which would let any client shut servers down.
        if self.client_id != 0:
            raise DGLError('Only client 0 can shut down servers, got shutdown '
                           'request from client {}.'.format(self.client_id))
        finalize_server()
        # sys.exit raises SystemExit for a clean unwind; the built-in exit()
        # comes from the site module and is not guaranteed to exist.
        import sys
        sys.exit(0)
_init_api("dgl.distributed.rpc")
"""Functions used by client."""
import os
import socket
from . import rpc
from .constants import MAX_QUEUE_SIZE
if os.name != 'nt':
import fcntl
import struct
def local_ip4_addr_list():
    """Return a set of IPv4 addresses bound to this machine's interfaces.

    Returns
    -------
    set of str
        One dotted-quad address per network interface.
    """
    assert os.name != 'nt', 'Do not support Windows rpc yet.'
    nic = set()
    for _, if_name in socket.if_nameindex():
        # The UDP socket is used only as a file descriptor for the ioctl;
        # the original leaked one socket per interface, so close it via
        # the context manager.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            ip_addr = socket.inet_ntoa(fcntl.ioctl(
                sock.fileno(),
                0x8915,  # SIOCGIFADDR: fetch the interface's IPv4 address
                struct.pack('256s', if_name[:15].encode("UTF-8")))[20:24])
        nic.add(ip_addr)
    return nic
def get_local_machine_id(server_namebook):
    """Given server_namebook, find local machine ID.

    Parameters
    ----------
    server_namebook : dict
        IP address namebook of server nodes, where key is the server's ID
        (start from 0) and value is the server's machine_id, IP address,
        port, and group_count, e.g.,

        {0:[0, '172.31.40.143', 30050, 2],
         1:[0, '172.31.40.143', 30051, 2],
         2:[1, '172.31.36.140', 30050, 2],
         3:[1, '172.31.36.140', 30051, 2],
         4:[2, '172.31.47.147', 30050, 2],
         5:[2, '172.31.47.147', 30051, 2],
         6:[3, '172.31.30.180', 30050, 2],
         7:[3, '172.31.30.180', 30051, 2]}

    Returns
    -------
    int
        local machine ID (0 when no namebook entry matches a local address)
    """
    local_ips = local_ip4_addr_list()
    for entry in server_namebook.values():
        # entry = [machine_id, ip_addr, port, group_count]
        if entry[1] in local_ips:
            return entry[0]
    return 0
def get_local_usable_addr():
    """Get local usable IP and port.

    Returns
    -------
    str
        IP address, e.g., '192.168.8.12:50051'
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # Connecting a UDP socket sends no packet; it only makes the OS pick
        # the outgoing interface, whose address we read back. The target
        # doesn't even have to be reachable.
        sock.connect(('10.255.255.255', 1))
        ip_addr = sock.getsockname()[0]
    except (OSError, ValueError):
        # Socket failures raise OSError; the original caught only ValueError,
        # so the loopback fallback never triggered on an unreachable network.
        ip_addr = '127.0.0.1'
    finally:
        sock.close()
    # Bind to port 0 so the OS assigns a free TCP port, then release it.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    sock.listen(1)
    port = sock.getsockname()[1]
    sock.close()
    return ip_addr + ':' + str(port)
def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
    """Connect this client to server.

    Parameters
    ----------
    ip_config : str
        Path of server IP configuration file.
    max_queue_size : int
        Maximal size (bytes) of client queue buffer (~20 GB on default).
        Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
        it will not allocate 20GB memory at once.
    net_type : str
        Networking type. Current options are: 'socket'.

    Raises
    ------
    ConnectionError : If anything wrong with the connection.
    """
    assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
    # ('socket',) must be a one-element tuple: `x in ('socket')` is a
    # substring test on the string 'socket', so e.g. 'sock' would pass.
    assert net_type in ('socket',), 'net_type (%s) can only be \'socket\'.' % net_type
    # Register some basic service
    rpc.register_service(rpc.CLIENT_REGISTER,
                         rpc.ClientRegisterRequest,
                         rpc.ClientRegisterResponse)
    rpc.register_service(rpc.SHUT_DOWN_SERVER,
                         rpc.ShutDownRequest,
                         None)
    server_namebook = rpc.read_ip_config(ip_config)
    num_servers = len(server_namebook)
    rpc.set_num_server(num_servers)
    # server_info = [machine_id, ip, port, group_count]; the number of
    # machines is one more than the largest machine ID seen.
    max_machine_id = 0
    for server_info in server_namebook.values():
        max_machine_id = max(max_machine_id, server_info[0])
    num_machines = max_machine_id + 1
    rpc.set_num_machines(num_machines)
    machine_id = get_local_machine_id(server_namebook)
    rpc.set_machine_id(machine_id)
    rpc.create_sender(max_queue_size, net_type)
    rpc.create_receiver(max_queue_size, net_type)
    # Get connected with all server nodes
    for server_id, addr in server_namebook.items():
        server_ip = addr[1]
        server_port = addr[2]
        rpc.add_receiver_addr(server_ip, server_port, server_id)
    rpc.sender_connect()
    # Get local usable IP address and port
    ip_addr = get_local_usable_addr()
    client_ip, client_port = ip_addr.split(':')
    # Register client on server
    # 0 is a temp ID because we haven't assigned client ID yet
    rpc.set_rank(0)
    register_req = rpc.ClientRegisterRequest(ip_addr)
    for server_id in range(num_servers):
        rpc.send_request(server_id, register_req)
    # wait server connect back
    rpc.receiver_wait(client_ip, client_port, num_servers)
    # recv client ID from server
    res = rpc.recv_response()
    rpc.set_rank(res.client_id)
    print("Machine (%d) client (%d) connect to server successfuly!" \
          % (machine_id, rpc.get_rank()))
def finalize_client():
    """Release resources of this client.

    Tears down the rpc sender first, then the receiver.
    """
    rpc.finalize_sender()
    rpc.finalize_receiver()
def shutdown_servers():
    """Issue commands to remote servers to shut them down.

    Only the rank-0 client sends the shutdown request; any other client
    returns without doing anything.

    Raises
    ------
    ConnectionError : If anything wrong with the connection.
    """
    if rpc.get_rank() != 0:
        return  # only client_0 issues this command
    req = rpc.ShutDownRequest(rpc.get_rank())
    for server_id in range(rpc.get_num_server()):
        rpc.send_request(server_id, req)
"""Functions used by server."""
from . import rpc
from .constants import MAX_QUEUE_SIZE
def start_server(server_id, ip_config, num_clients,
                 max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
    """Start DGL server, which will be shared with all the rpc services.

    This is a blocking function -- it returns only when the server shutdown.

    Parameters
    ----------
    server_id : int
        Current server ID (starts from 0).
    ip_config : str
        Path of IP configuration file.
    num_clients : int
        Total number of clients that will be connected to the server.
        Note that, we do not support dynamic connection for now. It means
        that no new client can be added to the cluster after all the
        clients have connected to the server.
    max_queue_size : int
        Maximal size (bytes) of server queue buffer (~20 GB on default).
        Note that the 20 GB is just an upper-bound because DGL uses zero-copy and
        it will not allocate 20GB memory at once.
    net_type : str
        Networking type. Current options are: 'socket'.
    """
    assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
    # The next two asserts previously formatted their messages with the
    # undefined names `num_client` and `queue_size` (NameError on failure).
    assert num_clients >= 0, 'num_client (%d) cannot be a negative number.' % num_clients
    assert max_queue_size > 0, 'queue_size (%d) cannot be a negative number.' % max_queue_size
    # ('socket',) must be a one-element tuple: `x in ('socket')` is a
    # substring test on the string 'socket'.
    assert net_type in ('socket',), 'net_type (%s) can only be \'socket\'' % net_type
    # Register some basic services
    rpc.register_service(rpc.CLIENT_REGISTER,
                         rpc.ClientRegisterRequest,
                         rpc.ClientRegisterResponse)
    rpc.register_service(rpc.SHUT_DOWN_SERVER,
                         rpc.ShutDownRequest,
                         None)
    rpc.set_rank(server_id)
    server_namebook = rpc.read_ip_config(ip_config)
    # namebook entry = [machine_id, ip, port, group_count]
    machine_id = server_namebook[server_id][0]
    rpc.set_machine_id(machine_id)
    ip_addr = server_namebook[server_id][1]
    port = server_namebook[server_id][2]
    rpc.create_sender(max_queue_size, net_type)
    rpc.create_receiver(max_queue_size, net_type)
    # wait all the senders connect to server.
    # Once all the senders connect to server, server will not
    # accept new sender's connection
    print("Wait connections ...")
    rpc.receiver_wait(ip_addr, port, num_clients)
    print("%d clients connected!" % num_clients)
    # Recv all the client's IP and assign ID to clients
    addr_list = []
    client_namebook = {}
    for _ in range(num_clients):
        req, _ = rpc.recv_request()
        addr_list.append(req.ip_addr)
    # Sorting makes the ID assignment deterministic, so every server
    # derives the same client IDs independently.
    addr_list.sort()
    for client_id, addr in enumerate(addr_list):
        client_namebook[client_id] = addr
    for client_id, addr in client_namebook.items():
        client_ip, client_port = addr.split(':')
        rpc.add_receiver_addr(client_ip, client_port, client_id)
    rpc.sender_connect()
    if rpc.get_rank() == 0:  # server_0 send all the IDs
        for client_id, _ in client_namebook.items():
            register_res = rpc.ClientRegisterResponse(client_id)
            rpc.send_response(client_id, register_res)
    # main service loop: dispatch each incoming request to its handler and
    # send the response back when the handler produces one
    server_state = None
    while True:
        req, client_id = rpc.recv_request()
        res = req.process_request(server_state)
        if res is not None:
            rpc.send_response(client_id, res)
"""Server data"""
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
@register_object('server_state.ServerState')
class ServerState(ObjectBase):
    """Data stored in one DGL server.

    In a distributed setting, DGL partitions all data associated with the graph
    (e.g., node and edge features, graph structure, etc.) to multiple partitions,
    each handled by one DGL server. Hence, the ServerState class includes all
    the data associated with a graph partition.

    Under some setup, users may want to deploy servers in a heterogeneous way
    -- servers are further divided into special groups for fetching/updating
    node/edge data and for sampling/querying on graph structure respectively.
    In this case, the ServerState can be configured to include only node/edge
    data or graph structure.

    Each machine can have multiple server and client processes, but only one
    server is the *master* server while all the others are backup servers. All
    clients and backup servers share the state of the master server via shared
    memory, which means the ServerState class must be serializable and large
    bulk data (e.g., node/edge features) must be stored in NDArray to leverage
    shared memory.

    Attributes
    ----------
    kv_store : dict[str, Tensor]
        Key value store for tensor data
    graph : DGLHeteroGraph
        Graph structure of one partition
    total_num_nodes : int
        Total number of nodes
    total_num_edges : int
        Total number of edges
    """
    # All properties are thin wrappers over the C API; the actual state
    # object lives in the C++ runtime.
    @property
    def kv_store(self):
        """Get KV store."""
        return _CAPI_DGLRPCServerStateGetKVStore(self)

    @property
    def graph(self):
        """Get graph."""
        return _CAPI_DGLRPCServerStateGetGraph(self)

    @property
    def total_num_nodes(self):
        """Get total number of nodes."""
        return _CAPI_DGLRPCServerStateGetTotalNumNodes(self)

    @property
    def total_num_edges(self):
        """Get total number of edges."""
        return _CAPI_DGLRPCServerStateGetTotalNumEdges(self)
def get_server_state():
    """Get server state data.

    If the process is a server, this stores necessary
    server-side data. Otherwise, the process is a client and it stores a cache
    of the server co-located with the client (if available). When the client
    invokes a RPC to the co-located server, it can thus perform computation
    locally without an actual remote call.

    Returns
    -------
    ServerState
        Server state data
    """
    # Thin wrapper over the C API; the state object lives in the C++ runtime.
    return _CAPI_DGLRPCGetServerState()
# Load the C API functions (_CAPI_DGLRPCServerState*) into this module.
_init_api("dgl.distributed.server_state")
...@@ -41,13 +41,6 @@ typedef void* CommunicatorHandle; ...@@ -41,13 +41,6 @@ typedef void* CommunicatorHandle;
// KVstore message handler type // KVstore message handler type
typedef void* KVMsgHandle; typedef void* KVMsgHandle;
/*! \brief Enum type for bool value with unknown */
enum BoolFlag {
kBoolUnknown = -1,
kBoolFalse = 0,
kBoolTrue = 1
};
/*! /*!
* \brief Convert a vector of NDArray to PackedFunc. * \brief Convert a vector of NDArray to PackedFunc.
*/ */
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
#include <unordered_map> #include <unordered_map>
#include "./network/communicator.h" #include "../rpc/network/communicator.h"
#include "./network/socket_communicator.h" #include "../rpc/network/socket_communicator.h"
#include "./network/msg_queue.h" #include "../rpc/network/msg_queue.h"
#include "./network/common.h" #include "../rpc/network/common.h"
using dgl::network::StringPrintf; using dgl::network::StringPrintf;
using namespace dgl::runtime; using namespace dgl::runtime;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include <string> #include <string>
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./network/msg_queue.h" #include "../rpc/network/msg_queue.h"
using dgl::runtime::NDArray; using dgl::runtime::NDArray;
......
/*!
* Copyright (c) 2020 by Contributors
* \file graph/serailize/zerocopy_serializer.cc
* \brief serializer implementation.
*/
#include <dgl/zerocopy_serializer.h>
#include "dgl/runtime/ndarray.h"
namespace dgl {
using dgl::runtime::NDArray;
/*!
 * \brief Manager context for a DLManagedTensor built from a raw pointer.
 *
 * Owns the shape/stride vectors that the embedded DLTensor points into,
 * keeping them alive for the lifetime of the tensor.
 */
struct RawDataTensorCtx {
  std::vector<int64_t> shape;
  std::vector<int64_t> stride;
  DLManagedTensor tensor;
};
// DLPack deleter for tensors created by CreateNDArrayFromRawData: frees the
// raw data buffer (assumes it was malloc-allocated -- TODO confirm at call
// sites) and the context that owns the shape/stride storage.
// NOTE(review): name has a typo ("Tenso"); renaming would touch all uses.
void RawDataTensoDLPackDeleter(DLManagedTensor* tensor) {
  auto ctx = static_cast<RawDataTensorCtx*>(tensor->manager_ctx);
  free(ctx->tensor.dl_tensor.data);
  delete ctx;
}
/*!
 * \brief Wrap a raw data pointer into an NDArray via the DLPack protocol.
 *
 * Ownership of \a raw transfers to the returned array; it is released with
 * free() by RawDataTensoDLPackDeleter when the array dies, so \a raw must
 * come from a malloc-compatible allocation -- TODO confirm at call sites.
 */
NDArray CreateNDArrayFromRawData(std::vector<int64_t> shape, DLDataType dtype,
                                 DLContext ctx, void* raw) {
  auto dlm_tensor_ctx = new RawDataTensorCtx();
  DLManagedTensor* dlm_tensor = &dlm_tensor_ctx->tensor;
  // Store the shape inside the context so the DLTensor's shape pointer
  // stays valid after this function returns.
  dlm_tensor_ctx->shape = shape;
  dlm_tensor->manager_ctx = dlm_tensor_ctx;
  dlm_tensor->dl_tensor.shape = dmlc::BeginPtr(dlm_tensor_ctx->shape);
  dlm_tensor->dl_tensor.ctx = ctx;
  dlm_tensor->dl_tensor.ndim = static_cast<int>(shape.size());
  dlm_tensor->dl_tensor.dtype = dtype;
  // Compute row-major (C-contiguous) strides, in elements: innermost stride
  // is 1, each outer stride is the product of the inner dimensions.
  dlm_tensor_ctx->stride.resize(dlm_tensor->dl_tensor.ndim, 1);
  for (int i = dlm_tensor->dl_tensor.ndim - 2; i >= 0; --i) {
    dlm_tensor_ctx->stride[i] =
        dlm_tensor_ctx->shape[i + 1] * dlm_tensor_ctx->stride[i + 1];
  }
  dlm_tensor->dl_tensor.strides = dmlc::BeginPtr(dlm_tensor_ctx->stride);
  dlm_tensor->dl_tensor.data = raw;
  dlm_tensor->deleter = RawDataTensoDLPackDeleter;
  return NDArray::FromDLPack(dlm_tensor);
}
// Serialize one NDArray into the stream. Metadata (ndim, dtype, shape) is
// always written inline; the data payload is either attached out-of-band in
// buffer_list_ (remote / non-shared-memory case) or replaced by the shared
// memory name (local shared-memory case), distinguished by a bool flag that
// PopNDArray reads back.
void StringStreamWithBuffer::PushNDArray(const NDArray& tensor) {
#ifndef _WIN32
  auto strm = static_cast<dmlc::Stream*>(this);
  strm->Write(tensor->ndim);
  strm->Write(tensor->dtype);
  int ndim = tensor->ndim;
  strm->WriteArray(tensor->shape, ndim);
  // Only contiguous, zero-offset tensors can be shipped as a single buffer.
  CHECK(tensor.IsContiguous())
      << "StringStreamWithBuffer only supports contiguous tensor";
  CHECK_EQ(tensor->byte_offset, 0)
      << "StringStreamWithBuffer only supports zero byte offset tensor";
  // NOTE(review): assumes dtype.bits is a multiple of 8 -- confirm callers
  // never pass sub-byte dtypes.
  int type_bytes = tensor->dtype.bits / 8;
  int64_t num_elems = 1;
  for (int i = 0; i < ndim; ++i) {
    num_elems *= tensor->shape[i];
  }
  int64_t data_byte_size = type_bytes * num_elems;
  auto mem = tensor.GetSharedMem();
  if (send_to_remote_ || !mem) {
    // If the stream is for remote communication or the data is not stored in
    // shared memory, serialize the data content as a buffer.
    // `false` flag = payload travels out-of-band via buffer_list_.
    strm->Write<bool>(false);
    buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);
  } else {
    // Always true on this branch; kept as a defensive check.
    CHECK(mem) << "Tried to send non-shared-memroy tensor to local "
                  "StringStreamWithBuffer";
    // Serialize only the shared memory name.
    strm->Write<bool>(true);
    strm->Write(mem->GetName());
  }
#else
  LOG(FATAL) << "StringStreamWithBuffer is not supported on windows";
#endif  // _WIN32
  return;
}
// Deserialize one NDArray written by PushNDArray. Reads the inline metadata,
// then either maps the named shared memory segment (local case) or wraps the
// next out-of-band buffer from buffer_list_ (remote case).
NDArray StringStreamWithBuffer::PopNDArray() {
#ifndef _WIN32
  auto strm = static_cast<dmlc::Stream*>(this);
  int ndim;
  DLDataType dtype;
  CHECK(strm->Read(&ndim)) << "Invalid DLTensor file format";
  CHECK(strm->Read(&dtype)) << "Invalid DLTensor file format";
  std::vector<int64_t> shape(ndim);
  if (ndim != 0) {
    CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format";
  }
  // Deserialized tensors always land on CPU.
  DLContext cpu_ctx;
  cpu_ctx.device_type = kDLCPU;
  cpu_ctx.device_id = 0;
  bool is_shared_mem;
  CHECK(strm->Read(&is_shared_mem)) << "Invalid stream read";
  if (is_shared_mem) {
    CHECK(!send_to_remote_) << "Invalid attempt to deserialize from shared "
                               "memory with send_to_remote=true";
    // The stream only carries the shared-memory segment name.
    std::string sharedmem_name;
    CHECK(strm->Read(&sharedmem_name)) << "Invalid stream read";
    return NDArray::EmptyShared(sharedmem_name, shape, dtype, cpu_ctx, false);
  } else {
    CHECK(send_to_remote_) << "Invalid attempt to deserialize from raw data "
                              "pointer with send_to_remote=false";
    // Guard against a malformed stream that announces more raw-data tensors
    // than buffers were attached (front() on an empty deque is UB).
    CHECK(!buffer_list_.empty())
        << "Invalid stream read: no data buffer left to pop";
    auto ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx,
                                        buffer_list_.front().data);
    buffer_list_.pop_front();
    return ret;
  }
#else
  LOG(FATAL) << "StringStreamWithBuffer is not supported on windows";
  return NDArray();
#endif  // _WIN32
}
} // namespace dgl
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* \brief This file provide basic facilities for string * \brief This file provide basic facilities for string
* to make programming convenient. * to make programming convenient.
*/ */
#ifndef DGL_GRAPH_NETWORK_COMMON_H_ #ifndef DGL_RPC_NETWORK_COMMON_H_
#define DGL_GRAPH_NETWORK_COMMON_H_ #define DGL_RPC_NETWORK_COMMON_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
...@@ -130,4 +130,4 @@ void StringAppendF(std::string* dst, const char* format, ...); ...@@ -130,4 +130,4 @@ void StringAppendF(std::string* dst, const char* format, ...);
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_NETWORK_COMMON_H_ #endif // DGL_RPC_NETWORK_COMMON_H_
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
* \file communicator.h * \file communicator.h
* \brief Communicator for DGL distributed training. * \brief Communicator for DGL distributed training.
*/ */
#ifndef DGL_GRAPH_NETWORK_COMMUNICATOR_H_ #ifndef DGL_RPC_NETWORK_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_COMMUNICATOR_H_ #define DGL_RPC_NETWORK_COMMUNICATOR_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
...@@ -170,4 +170,4 @@ class Receiver { ...@@ -170,4 +170,4 @@ class Receiver {
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
#endif // DGL_GRAPH_NETWORK_COMMUNICATOR_H_ #endif // DGL_RPC_NETWORK_COMMUNICATOR_H_
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment