Unverified commit 3e696922, authored by Chao Ma and committed by GitHub

[RPC] New RPC infrastructure. (#1549)



* WIP: rpc components

* client & server

* move network package to rpc

* fix include

* fix compile

* c api

* wip: test

* add basic tests

* missing file

* [RPC] Zero copy serializer (#1517)

* zerocopy serialization

* add test for HeteroGraph

* fix lint

* remove unnecessary codes

* add comment

* lint

* lint

* disable pylint for now

* add include for win

* windows guard

* lint

* lint

* skip test on windows

* refactor

* add comment

* fix

* comment

* 1111

* fix

* Update Jenkinsfile

* [RPC] Implementation of RPC infra (#1544)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* remove client.cc and server.cc

* fix lint

* update

* update

* fix lint

* update

* fix lint

* update

* update

* update

* update

* update

* update

* update test

* update

* update test

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update comment

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7ba9cbc6
@@ -109,6 +109,7 @@ file(GLOB_RECURSE DGL_SRC_1
   src/api/*.cc
   src/graph/*.cc
   src/scheduler/*.cc
+  src/rpc/*.cc
 )
 list(APPEND DGL_SRC ${DGL_SRC_1})
@@ -6,11 +6,14 @@
 #ifndef DGL_RUNTIME_NDARRAY_H_
 #define DGL_RUNTIME_NDARRAY_H_

-#include <string>
 #include <atomic>
-#include <vector>
+#include <string>
 #include <utility>
+#include <vector>

 #include "c_runtime_api.h"
+#include "dlpack/dlpack.h"
 #include "serializer.h"
 #include "shared_mem.h"
@@ -158,12 +161,12 @@ class NDArray {
    * \param stream The input data stream
    * \return Whether load is successful
    */
-  inline bool Load(dmlc::Stream* stream);
+  bool Load(dmlc::Stream* stream);
   /*!
    * \brief Save NDArray to stream
    * \param stream The output data stream
    */
-  inline void Save(dmlc::Stream* stream) const;
+  void Save(dmlc::Stream* stream) const;
   /*!
    * \brief Create a NDArray that shares the data memory with the current one.
    * \param shape The shape of the new array.
@@ -237,6 +240,10 @@ class NDArray {
   template<typename T>
   std::vector<T> ToVector() const;

+#ifndef _WIN32
+  std::shared_ptr<SharedMemory> GetSharedMem() const;
+#endif  // _WIN32
+
   /*!
    * \brief Function to copy data from one array to another.
    * \param from The source array.
@@ -264,6 +271,7 @@
  */
 inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
+
 /*!
  * \brief Reference counted Container object used to back NDArray.
  *
@@ -456,57 +464,6 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
   return true;
 }

-inline void NDArray::Save(dmlc::Stream* strm) const {
-  SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
-}
-
-inline bool NDArray::Load(dmlc::Stream* strm) {
-  uint64_t header, reserved;
-  CHECK(strm->Read(&header))
-      << "Invalid DLTensor file format";
-  CHECK(strm->Read(&reserved))
-      << "Invalid DLTensor file format";
-  CHECK(header == kDGLNDArrayMagic)
-      << "Invalid DLTensor file format";
-  DLContext ctx;
-  int ndim;
-  DLDataType dtype;
-  CHECK(strm->Read(&ctx))
-      << "Invalid DLTensor file format";
-  CHECK(strm->Read(&ndim))
-      << "Invalid DLTensor file format";
-  CHECK(strm->Read(&dtype))
-      << "Invalid DLTensor file format";
-  CHECK_EQ(ctx.device_type, kDLCPU)
-      << "Invalid DLTensor context: can only save as CPU tensor";
-  std::vector<int64_t> shape(ndim);
-  if (ndim != 0) {
-    CHECK(strm->ReadArray(&shape[0], ndim))
-        << "Invalid DLTensor file format";
-  }
-  NDArray ret = NDArray::Empty(shape, dtype, ctx);
-  int64_t num_elems = 1;
-  int elem_bytes = (ret->dtype.bits + 7) / 8;
-  for (int i = 0; i < ret->ndim; ++i) {
-    num_elems *= ret->shape[i];
-  }
-  int64_t data_byte_size;
-  CHECK(strm->Read(&data_byte_size))
-      << "Invalid DLTensor file format";
-  CHECK(data_byte_size == num_elems * elem_bytes)
-      << "Invalid DLTensor file format";
-  if (data_byte_size != 0) {
-    // strm->Read returns the total number of elements successfully read.
-    // Therefore if data_byte_size were zero, the CHECK below would fail.
-    CHECK(strm->Read(ret->data, data_byte_size))
-        << "Invalid DLTensor file format";
-  }
-  if (!DMLC_IO_NO_ENDIAN_SWAP) {
-    dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
-  }
-  *this = ret;
-  return true;
-}

 }  // namespace runtime
 }  // namespace dgl
@@ -11,7 +11,6 @@
 #include <dmlc/serializer.h>

 #include "c_runtime_api.h"
 #include "smart_ptr_serializer.h"
-#include "ndarray.h"

 namespace dmlc {
 namespace serializer {
@@ -27,6 +27,14 @@ class SharedMemory {
   * and will be responsible for deleting it when the object is destroyed.
   */
  bool own;
+
+  /* \brief the file descriptor of the shared memory. */
+  int fd;
+  /* \brief the address of the shared memory. */
+  void *ptr;
+  /* \brief the size of the shared memory. */
+  size_t size;
+
  /*
   * \brief the name of the object.
   *
@@ -34,14 +42,12 @@
   * the file name that identifies the shared memory.
   */
  std::string name;
-  /* \brief the file descriptor of the shared memory. */
-  int fd;
-  /* \brief the address of the shared memory. */
-  void *ptr;
-  /* \brief the size of the shared memory. */
-  size_t size;

 public:
+  /* \brief Get the file name that identifies the shared memory. */
+  std::string GetName() const { return name; }
+
  /*
   * \brief constructor of the shared memory.
   * \param name The file corresponding to the shared memory.
/*!
* Copyright (c) 2020 by Contributors
* \file dgl/zerocopy_serializer.h
* \brief Headers for the zero-copy serializer.
*/
#ifndef DGL_ZEROCOPY_SERIALIZER_H_
#define DGL_ZEROCOPY_SERIALIZER_H_
#include <dgl/runtime/ndarray.h>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <dmlc/serializer.h>
#include <deque>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "dmlc/logging.h"
namespace dgl {
// StringStreamWithBuffer is backed by a string. This class supports
// serializing and deserializing NDArrays stored in shared memory. If the
// stream is created for sending/receiving data through the network, the data
// pointer of the NDArray is transmitted directly, without any copy. Otherwise,
// the stream is for sending/receiving data to another process on the same
// machine; if an NDArray is stored in shared memory, only the shared memory
// name is recorded instead of the actual data buffer.
class StringStreamWithBuffer : public dmlc::MemoryStringStream {
public:
  // Buffer type. Stores the NDArray to maintain reference counting, which
  // keeps the underlying data pointer alive.
struct Buffer {
dgl::runtime::NDArray tensor = dgl::runtime::NDArray();
void* data = nullptr;
int64_t size = 0;
Buffer(const dgl::runtime::NDArray& tensor, void* data, int64_t data_size)
: tensor(tensor), data(data), size(data_size) {}
explicit Buffer(void* data) : data(data) {}
};
  /*!
   * \brief This constructor is for the writing scenario, or for reading on
   * the local machine.
   * \param metadata_ptr The string that zero-copy write/load operates on.
   * \param send_to_remote Whether this stream will be deserialized on a
   * remote machine or on the local machine. If true, data pointers are
   * recorded in the buffer list.
   *
   * For example:
   *   std::string blob;
   *   // Write, to send to a local process
   *   StringStreamWithBuffer buf_strm(&blob, false);
   *   // Write, to send to a remote machine
   *   StringStreamWithBuffer buf_strm(&blob, true);
   *   // Or
   *   StringStreamWithBuffer buf_strm(&blob);
   *   // Read from a local process
   *   StringStreamWithBuffer buf_strm(&blob, false);
   */
explicit StringStreamWithBuffer(std::string* metadata_ptr,
bool send_to_remote = true)
: MemoryStringStream(metadata_ptr),
buffer_list_(),
send_to_remote_(send_to_remote) {}
  /*!
   * \brief This constructor is for reading data sent from a remote machine.
   * \param metadata_ptr The string that zero-copy write/load operates on.
   * \param data_ptr_list List of data pointers used to reconstruct NDArrays.
   *
   * For example:
   *   std::string blob;
   *   std::vector<void*> data_ptr_list;
   *   // Read from the pointer list sent by the remote peer
   *   StringStreamWithBuffer buf_strm(&blob, data_ptr_list);
   */
StringStreamWithBuffer(std::string* metadata_ptr,
const std::vector<void*>& data_ptr_list)
: MemoryStringStream(metadata_ptr), send_to_remote_(true) {
for (void* data : data_ptr_list) {
buffer_list_.emplace_back(data);
}
}
  /*!
   * \brief Push an NDArray into the stream.
   * If send_to_remote=true, the NDArray is recorded in the buffer list.
   * If send_to_remote=false, only its shared memory name is written into the
   * backing string.
   */
void PushNDArray(const runtime::NDArray& tensor);
  /*!
   * \brief Pop an NDArray from the stream.
   * If send_to_remote=true, the NDArray is reconstructed from the buffer
   * list. If send_to_remote=false, it is reconstructed from shared memory.
   */
dgl::runtime::NDArray PopNDArray();
/*!
* \brief Get whether this stream is for remote usage
*/
bool send_to_remote() { return send_to_remote_; }
/*!
* \brief Get underlying buffer list
*/
const std::deque<Buffer>& buffer_list() const { return buffer_list_; }
private:
std::deque<Buffer> buffer_list_;
bool send_to_remote_;
};
} // namespace dgl
#endif // DGL_ZEROCOPY_SERIALIZER_H_
@@ -10,6 +10,7 @@ from .backend import load_backend, backend_name
 from . import function
 from . import contrib
 from . import container
+from . import distributed
 from . import random
 from . import sampling
@@ -49,6 +49,7 @@ class ObjectBase(object):
     def __getattr__(self, name):
         if name == 'handle':
             raise AttributeError("'handle' is a reserved attribute name that should not be used")
+        print('in get attr:', name)
         ret_val = DGLValue()
         ret_type_code = ctypes.c_int()
         ret_success = ctypes.c_int()
@@ -3,3 +3,7 @@
 from .dist_graph import DistGraphServer, DistGraph, node_split, edge_split
 from .partition import partition_graph, load_partition
 from .graph_partition_book import GraphPartitionBook
+from .rpc import *
+from .rpc_server import start_server
+from .rpc_client import connect_to_server, finalize_client, shutdown_servers
"""Define all the constants used by DGL rpc"""
# Maximum size of message queue in bytes
MAX_QUEUE_SIZE = 20*1024*1024*1024
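For scale, 20*1024*1024*1024 bytes is 20 GiB, and this is only an upper bound on the queue: zero-copy transfer means the memory is never allocated upfront. A hedged sketch of overriding the default (the 1 GiB value and the ip_config.txt path are hypothetical; connect_to_server is defined in rpc_client.py below):

from dgl.distributed.constants import MAX_QUEUE_SIZE
from dgl.distributed import connect_to_server

assert MAX_QUEUE_SIZE == 20 * 1024**3  # default 20 GiB upper bound

# Hypothetical override: cap the client message queue at 1 GiB instead.
connect_to_server('ip_config.txt', max_queue_size=1 * 1024**3)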
"""Functions used by client."""
import os
import socket
from . import rpc
from .constants import MAX_QUEUE_SIZE
if os.name != 'nt':
import fcntl
import struct
def local_ip4_addr_list():
"""Return a set of IPv4 address
"""
assert os.name != 'nt', 'Do not support Windows rpc yet.'
nic = set()
for if_nidx in socket.if_nameindex():
name = if_nidx[1]
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip_addr = socket.inet_ntoa(fcntl.ioctl(
sock.fileno(),
0x8915, # SIOCGIFADDR
struct.pack('256s', name[:15].encode("UTF-8")))[20:24])
nic.add(ip_addr)
return nic
def get_local_machine_id(server_namebook):
"""Given server_namebook, find local machine ID
Parameters
----------
    server_namebook: dict
        IP address namebook of server nodes, where the key is the server's ID
        (starting from 0) and the value is the server's machine_id, IP address,
        port, and group_count, e.g.,
          {0:[0, '172.31.40.143', 30050, 2],
           1:[0, '172.31.40.143', 30051, 2],
           2:[1, '172.31.36.140', 30050, 2],
           3:[1, '172.31.36.140', 30051, 2],
           4:[2, '172.31.47.147', 30050, 2],
           5:[2, '172.31.47.147', 30051, 2],
           6:[3, '172.31.30.180', 30050, 2],
           7:[3, '172.31.30.180', 30051, 2]}
Returns
-------
int
local machine ID
"""
res = 0
ip_list = local_ip4_addr_list()
for _, data in server_namebook.items():
machine_id = data[0]
ip_addr = data[1]
if ip_addr in ip_list:
res = machine_id
break
return res
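To make the lookup concrete, here is a small sketch using the namebook layout from the docstring above; the assumption that 172.31.36.140 belongs to a local interface is purely illustrative:

server_namebook = {
    0: [0, '172.31.40.143', 30050, 2],
    1: [0, '172.31.40.143', 30051, 2],
    2: [1, '172.31.36.140', 30050, 2],
    3: [1, '172.31.36.140', 30051, 2],
}
# If a local interface owns 172.31.36.140, the loop matches server 2 and
# returns its machine_id, i.e. 1; with no match, the fallback is machine 0.
machine_id = get_local_machine_id(server_namebook)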
def get_local_usable_addr():
"""Get local usable IP and port
Returns
-------
str
IP address, e.g., '192.168.8.12:50051'
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # doesn't even have to be reachable; a UDP connect() only selects
        # the local interface that would route toward this address
        sock.connect(('10.255.255.255', 1))
        ip_addr = sock.getsockname()[0]
    except OSError:
        ip_addr = '127.0.0.1'
finally:
sock.close()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("", 0))
sock.listen(1)
port = sock.getsockname()[1]
sock.close()
return ip_addr + ':' + str(port)
def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Connect this client to server.
Parameters
----------
ip_config : str
Path of server IP configuration file.
    max_queue_size : int
        Maximal size (in bytes) of the client queue buffer (~20 GB by default).
        Note that 20 GB is just an upper bound; DGL uses zero-copy, so it
        will not allocate 20 GB of memory at once.
    net_type : str
        Networking type. Current options are: 'socket'.

    Raises
    ------
    ConnectionError : If anything goes wrong with the connection.
    """
    assert max_queue_size > 0, 'max_queue_size (%d) must be a positive number.' % max_queue_size
    assert net_type in ('socket',), 'net_type (%s) can only be \'socket\'.' % net_type
    # Register some basic services
rpc.register_service(rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
rpc.ClientRegisterResponse)
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
server_namebook = rpc.read_ip_config(ip_config)
num_servers = len(server_namebook)
rpc.set_num_server(num_servers)
    # group_count means how many servers
    # (main_server + backup_server) there are in total on each machine.
group_count = []
max_machine_id = 0
for server_info in server_namebook.values():
group_count.append(server_info[3])
if server_info[0] > max_machine_id:
max_machine_id = server_info[0]
num_machines = max_machine_id+1
rpc.set_num_machines(num_machines)
machine_id = get_local_machine_id(server_namebook)
rpc.set_machine_id(machine_id)
rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type)
# Get connected with all server nodes
for server_id, addr in server_namebook.items():
server_ip = addr[1]
server_port = addr[2]
rpc.add_receiver_addr(server_ip, server_port, server_id)
rpc.sender_connect()
# Get local usable IP address and port
ip_addr = get_local_usable_addr()
client_ip, client_port = ip_addr.split(':')
# Register client on server
# 0 is a temp ID because we haven't assigned client ID yet
rpc.set_rank(0)
register_req = rpc.ClientRegisterRequest(ip_addr)
for server_id in range(num_servers):
rpc.send_request(server_id, register_req)
    # wait for the servers to connect back
rpc.receiver_wait(client_ip, client_port, num_servers)
# recv client ID from server
res = rpc.recv_response()
rpc.set_rank(res.client_id)
print("Machine (%d) client (%d) connect to server successfuly!" \
% (machine_id, rpc.get_rank()))
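Putting the client-side pieces together, a minimal hypothetical driver looks like the sketch below; the ip_config.txt path is illustrative, and its contents must follow whatever format rpc.read_ip_config expects (it uses finalize_client and shutdown_servers, defined just below):

from dgl.distributed import connect_to_server, finalize_client, shutdown_servers

# Connect to every server listed in the (hypothetical) IP configuration file;
# this blocks until registration completes and a client ID is assigned.
connect_to_server('ip_config.txt', net_type='socket')

# ... issue requests to registered services here ...

shutdown_servers()  # only the rank-0 client actually sends the command
finalize_client()   # release sender/receiver resources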
def finalize_client():
"""Release resources of this client."""
rpc.finalize_sender()
rpc.finalize_receiver()
def shutdown_servers():
"""Issue commands to remote servers to shut them down.
Raises
------
ConnectionError : If anything wrong with the connection.
"""
    if rpc.get_rank() == 0:  # only client_0 issues this command
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)
"""Functions used by server."""
from . import rpc
from .constants import MAX_QUEUE_SIZE
def start_server(server_id, ip_config, num_clients, \
max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Start DGL server, which will be shared with all the rpc services.
This is a blocking function -- it returns only when the server shutdown.
Parameters
----------
server_id : int
Current server ID (starts from 0).
ip_config : str
Path of IP configuration file.
    num_clients : int
        Total number of clients that will connect to the server.
        Note that dynamic connection is not supported yet: once all the
        clients have connected to the server, no new client can be added
        to the cluster.
    max_queue_size : int
        Maximal size (in bytes) of the server queue buffer (~20 GB by default).
        Note that 20 GB is just an upper bound; DGL uses zero-copy, so it
        will not allocate 20 GB of memory at once.
net_type : str
Networking type. Current options are: 'socket'.
"""
    assert server_id >= 0, 'server_id (%d) cannot be a negative number.' % server_id
    assert num_clients >= 0, 'num_clients (%d) cannot be a negative number.' % num_clients
    assert max_queue_size > 0, 'max_queue_size (%d) must be a positive number.' % max_queue_size
    assert net_type in ('socket',), 'net_type (%s) can only be \'socket\'.' % net_type
# Register some basic services
rpc.register_service(rpc.CLIENT_REGISTER,
rpc.ClientRegisterRequest,
rpc.ClientRegisterResponse)
rpc.register_service(rpc.SHUT_DOWN_SERVER,
rpc.ShutDownRequest,
None)
rpc.set_rank(server_id)
server_namebook = rpc.read_ip_config(ip_config)
machine_id = server_namebook[server_id][0]
rpc.set_machine_id(machine_id)
ip_addr = server_namebook[server_id][1]
port = server_namebook[server_id][2]
rpc.create_sender(max_queue_size, net_type)
rpc.create_receiver(max_queue_size, net_type)
    # Wait for all the senders to connect. Once they have, the server
    # stops accepting new sender connections.
    print("Waiting for connections ...")
rpc.receiver_wait(ip_addr, port, num_clients)
print("%d clients connected!" % num_clients)
    # Receive all the clients' IP addresses and assign IDs to the clients
addr_list = []
client_namebook = {}
for _ in range(num_clients):
req, _ = rpc.recv_request()
addr_list.append(req.ip_addr)
addr_list.sort()
for client_id, addr in enumerate(addr_list):
client_namebook[client_id] = addr
for client_id, addr in client_namebook.items():
client_ip, client_port = addr.split(':')
rpc.add_receiver_addr(client_ip, client_port, client_id)
rpc.sender_connect()
    if rpc.get_rank() == 0:  # server_0 sends all the IDs
for client_id, _ in client_namebook.items():
register_res = rpc.ClientRegisterResponse(client_id)
rpc.send_response(client_id, register_res)
# main service loop
server_state = None
while True:
req, client_id = rpc.recv_request()
res = req.process_request(server_state)
if res is not None:
rpc.send_response(client_id, res)
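The loop above dispatches every incoming message to req.process_request(server_state), so a custom service is just a request/response pair registered under the same ID on both client and server. The sketch below is hedged: it assumes rpc.Request and rpc.Response base classes with this interface (the rpc module itself is not shown in this diff), and the service ID and class names are hypothetical:

from dgl.distributed import rpc

MY_SERVICE_ID = 6657  # hypothetical; must not clash with built-in service IDs

class CountNodesRequest(rpc.Request):
    """Hypothetical request asking a server for its node count."""
    def __getstate__(self):
        return None

    def __setstate__(self, state):
        pass

    def process_request(self, server_state):
        # Runs inside the server loop above; assumes the server populated a
        # ServerState (see the ServerState class below) instead of None.
        return CountNodesResponse(server_state.total_num_nodes)

class CountNodesResponse(rpc.Response):
    def __init__(self, num_nodes):
        self.num_nodes = num_nodes

    def __getstate__(self):
        return self.num_nodes

    def __setstate__(self, state):
        self.num_nodes = state

# Register the service on both client and server processes before use.
rpc.register_service(MY_SERVICE_ID, CountNodesRequest, CountNodesResponse)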
"""Server data"""
from .._ffi.object import register_object, ObjectBase
from .._ffi.function import _init_api
@register_object('server_state.ServerState')
class ServerState(ObjectBase):
"""Data stored in one DGL server.
In a distributed setting, DGL partitions all data associated with the graph
(e.g., node and edge features, graph structure, etc.) to multiple partitions,
each handled by one DGL server. Hence, the ServerState class includes all
the data associated with a graph partition.
    Under some setups, users may want to deploy servers in a heterogeneous way
    -- servers are further divided into special groups, some for
    fetching/updating node/edge data and others for sampling/querying the
    graph structure. In this case, the ServerState can be configured to
    include only node/edge data or only the graph structure.
Each machine can have multiple server and client processes, but only one
server is the *master* server while all the others are backup servers. All
clients and backup servers share the state of the master server via shared
memory, which means the ServerState class must be serializable and large
bulk data (e.g., node/edge features) must be stored in NDArray to leverage
shared memory.
Attributes
----------
kv_store : dict[str, Tensor]
Key value store for tensor data
graph : DGLHeteroGraph
Graph structure of one partition
total_num_nodes : int
Total number of nodes
total_num_edges : int
Total number of edges
"""
@property
def kv_store(self):
"""Get KV store."""
return _CAPI_DGLRPCServerStateGetKVStore(self)
@property
def graph(self):
"""Get graph."""
return _CAPI_DGLRPCServerStateGetGraph(self)
@property
def total_num_nodes(self):
"""Get total number of nodes."""
return _CAPI_DGLRPCServerStateGetTotalNumNodes(self)
@property
def total_num_edges(self):
"""Get total number of edges."""
return _CAPI_DGLRPCServerStateGetTotalNumEdges(self)
def get_server_state():
"""Get server state data.
If the process is a server, this stores necessary
server-side data. Otherwise, the process is a client and it stores a cache
of the server co-located with the client (if available). When the client
    invokes an RPC to the co-located server, it can thus perform computation
locally without an actual remote call.
Returns
-------
ServerState
Server state data
"""
return _CAPI_DGLRPCGetServerState()
_init_api("dgl.distributed.server_state")
@@ -41,13 +41,6 @@ typedef void* CommunicatorHandle;
 // KVstore message handler type
 typedef void* KVMsgHandle;

-/*! \brief Enum type for bool value with unknown */
-enum BoolFlag {
-  kBoolUnknown = -1,
-  kBoolFalse = 0,
-  kBoolTrue = 1
-};
-
 /*!
  * \brief Convert a vector of NDArray to PackedFunc.
  */
@@ -15,10 +15,10 @@
 #include <unordered_map>

-#include "./network/communicator.h"
-#include "./network/socket_communicator.h"
-#include "./network/msg_queue.h"
-#include "./network/common.h"
+#include "../rpc/network/communicator.h"
+#include "../rpc/network/socket_communicator.h"
+#include "../rpc/network/msg_queue.h"
+#include "../rpc/network/common.h"

 using dgl::network::StringPrintf;
 using namespace dgl::runtime;
@@ -14,7 +14,7 @@
 #include <string>

 #include "../c_api_common.h"
-#include "./network/msg_queue.h"
+#include "../rpc/network/msg_queue.h"

 using dgl::runtime::NDArray;
/*!
* Copyright (c) 2020 by Contributors
* \file graph/serialize/zerocopy_serializer.cc
* \brief serializer implementation.
*/
#include <dgl/zerocopy_serializer.h>
#include "dgl/runtime/ndarray.h"
namespace dgl {
using dgl::runtime::NDArray;
struct RawDataTensorCtx {
std::vector<int64_t> shape;
std::vector<int64_t> stride;
DLManagedTensor tensor;
};
void RawDataTensorDLPackDeleter(DLManagedTensor* tensor) {
auto ctx = static_cast<RawDataTensorCtx*>(tensor->manager_ctx);
free(ctx->tensor.dl_tensor.data);
delete ctx;
}
NDArray CreateNDArrayFromRawData(std::vector<int64_t> shape, DLDataType dtype,
DLContext ctx, void* raw) {
auto dlm_tensor_ctx = new RawDataTensorCtx();
DLManagedTensor* dlm_tensor = &dlm_tensor_ctx->tensor;
dlm_tensor_ctx->shape = shape;
dlm_tensor->manager_ctx = dlm_tensor_ctx;
dlm_tensor->dl_tensor.shape = dmlc::BeginPtr(dlm_tensor_ctx->shape);
dlm_tensor->dl_tensor.ctx = ctx;
dlm_tensor->dl_tensor.ndim = static_cast<int>(shape.size());
dlm_tensor->dl_tensor.dtype = dtype;
dlm_tensor_ctx->stride.resize(dlm_tensor->dl_tensor.ndim, 1);
for (int i = dlm_tensor->dl_tensor.ndim - 2; i >= 0; --i) {
dlm_tensor_ctx->stride[i] =
dlm_tensor_ctx->shape[i + 1] * dlm_tensor_ctx->stride[i + 1];
}
dlm_tensor->dl_tensor.strides = dmlc::BeginPtr(dlm_tensor_ctx->stride);
dlm_tensor->dl_tensor.data = raw;
  dlm_tensor->deleter = RawDataTensorDLPackDeleter;
return NDArray::FromDLPack(dlm_tensor);
}
void StringStreamWithBuffer::PushNDArray(const NDArray& tensor) {
#ifndef _WIN32
auto strm = static_cast<dmlc::Stream*>(this);
strm->Write(tensor->ndim);
strm->Write(tensor->dtype);
int ndim = tensor->ndim;
strm->WriteArray(tensor->shape, ndim);
CHECK(tensor.IsContiguous())
<< "StringStreamWithBuffer only supports contiguous tensor";
CHECK_EQ(tensor->byte_offset, 0)
<< "StringStreamWithBuffer only supports zero byte offset tensor";
int type_bytes = tensor->dtype.bits / 8;
int64_t num_elems = 1;
for (int i = 0; i < ndim; ++i) {
num_elems *= tensor->shape[i];
}
int64_t data_byte_size = type_bytes * num_elems;
auto mem = tensor.GetSharedMem();
if (send_to_remote_ || !mem) {
// If the stream is for remote communication or the data is not stored in
// shared memory, serialize the data content as a buffer.
strm->Write<bool>(false);
buffer_list_.emplace_back(tensor, tensor->data, data_byte_size);
} else {
    CHECK(mem) << "Tried to send a non-shared-memory tensor to a local "
                  "StringStreamWithBuffer";
// Serialize only the shared memory name.
strm->Write<bool>(true);
strm->Write(mem->GetName());
}
#else
  LOG(FATAL) << "StringStreamWithBuffer is not supported on Windows";
#endif // _WIN32
return;
}
NDArray StringStreamWithBuffer::PopNDArray() {
#ifndef _WIN32
auto strm = static_cast<dmlc::Stream*>(this);
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ndim)) << "Invalid DLTensor file format";
CHECK(strm->Read(&dtype)) << "Invalid DLTensor file format";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format";
}
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
cpu_ctx.device_id = 0;
bool is_shared_mem;
CHECK(strm->Read(&is_shared_mem)) << "Invalid stream read";
std::string sharedmem_name;
if (is_shared_mem) {
CHECK(!send_to_remote_) << "Invalid attempt to deserialize from shared "
"memory with send_to_remote=true";
CHECK(strm->Read(&sharedmem_name)) << "Invalid stream read";
return NDArray::EmptyShared(sharedmem_name, shape, dtype, cpu_ctx, false);
} else {
CHECK(send_to_remote_) << "Invalid attempt to deserialize from raw data "
"pointer with send_to_remote=false";
auto ret = CreateNDArrayFromRawData(shape, dtype, cpu_ctx,
buffer_list_.front().data);
buffer_list_.pop_front();
return ret;
}
#else
  LOG(FATAL) << "StringStreamWithBuffer is not supported on Windows";
return NDArray();
#endif // _WIN32
}
} // namespace dgl
@@ -4,8 +4,8 @@
  * \brief This file provides basic facilities for strings
  * to make programming convenient.
  */
-#ifndef DGL_GRAPH_NETWORK_COMMON_H_
-#define DGL_GRAPH_NETWORK_COMMON_H_
+#ifndef DGL_RPC_NETWORK_COMMON_H_
+#define DGL_RPC_NETWORK_COMMON_H_

 #include <dmlc/logging.h>

@@ -130,4 +130,4 @@ void StringAppendF(std::string* dst, const char* format, ...);
 }  // namespace network
 }  // namespace dgl

-#endif  // DGL_GRAPH_NETWORK_COMMON_H_
+#endif  // DGL_RPC_NETWORK_COMMON_H_
@@ -3,8 +3,8 @@
  * \file communicator.h
  * \brief Communicator for DGL distributed training.
  */
-#ifndef DGL_GRAPH_NETWORK_COMMUNICATOR_H_
-#define DGL_GRAPH_NETWORK_COMMUNICATOR_H_
+#ifndef DGL_RPC_NETWORK_COMMUNICATOR_H_
+#define DGL_RPC_NETWORK_COMMUNICATOR_H_

 #include <dmlc/logging.h>

@@ -170,4 +170,4 @@ class Receiver {
 }  // namespace network
 }  // namespace dgl

-#endif  // DGL_GRAPH_NETWORK_COMMUNICATOR_H_
+#endif  // DGL_RPC_NETWORK_COMMUNICATOR_H_