Unverified commit 3e696922 authored by Chao Ma, committed by GitHub

[RPC] New RPC infrastructure. (#1549)



* WIP: rpc components

* client & server

* move network package to rpc

* fix include

* fix compile

* c api

* wip: test

* add basic tests

* missing file

* [RPC] Zero copy serializer (#1517)

* zerocopy serialization

* add test for HeteroGraph

* fix lint

* remove unnecessary codes

* add comment

* lint

* lint

* disable pylint for now

* add include for win

* windows guard

* lint

* lint

* skip test on windows

* refactor

* add comment

* fix

* comment

* 1111

* fix

* Update Jenkinsfile

* [RPC] Implementation of RPC infra (#1544)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* remove client.cc and server.cc

* fix lint

* update

* update

* fix lint

* update

* fix lint

* update

* update

* update

* update

* update

* update

* update test

* update

* update test

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update comment

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 7ba9cbc6
@@ -3,8 +3,8 @@
* \file msg_queue.h
* \brief Message queue for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#define DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#ifndef DGL_RPC_NETWORK_MSG_QUEUE_H_
#define DGL_RPC_NETWORK_MSG_QUEUE_H_
#include <dgl/runtime/ndarray.h>
@@ -179,4 +179,4 @@ class MessageQueue {
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#endif // DGL_RPC_NETWORK_MSG_QUEUE_H_
@@ -3,8 +3,8 @@
* \file communicator.h
* \brief SocketCommunicator for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
#ifndef DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
#include <thread>
#include <vector>
@@ -221,4 +221,4 @@ class SocketReceiver : public Receiver {
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
#endif // DGL_RPC_NETWORK_SOCKET_COMMUNICATOR_H_
@@ -3,8 +3,8 @@
* \file tcp_socket.h
* \brief TCP socket for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_TCP_SOCKET_H_
#define DGL_GRAPH_NETWORK_TCP_SOCKET_H_
#ifndef DGL_RPC_NETWORK_TCP_SOCKET_H_
#define DGL_RPC_NETWORK_TCP_SOCKET_H_
#ifdef _WIN32
#include <winsock2.h>
@@ -130,4 +130,4 @@ class TCPSocket {
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_TCP_SOCKET_H_
#endif // DGL_RPC_NETWORK_TCP_SOCKET_H_
/*!
* Copyright (c) 2020 by Contributors
* \file rpc/rpc.cc
* \brief Implementation of RPC utilities used by both server and client sides.
*/
#include "./rpc.h"
#include <dgl/runtime/container.h>
#include <dgl/packed_func_ext.h>
#include <dgl/zerocopy_serializer.h>
#include "../c_api_common.h"
using dgl::network::StringPrintf;
using namespace dgl::runtime;
namespace dgl {
namespace rpc {
RPCStatus SendRPCMessage(const RPCMessage& msg) {
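// Wire format for one RPC call: a meta message containing the
// zero-copy-serialized RPCMessage (tensor contents excluded) followed by a
// trailing int32 tensor count, then one raw message per tensor payload.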
std::shared_ptr<std::string> zerocopy_blob(new std::string());
StringStreamWithBuffer zc_write_strm(zerocopy_blob.get());
static_cast<dmlc::Stream *>(&zc_write_strm)->Write(msg);
int32_t ndarray_count = msg.tensors.size();
zerocopy_blob->append(
reinterpret_cast<char*>(&ndarray_count),
sizeof(int32_t));
network::Message rpc_meta_msg;
rpc_meta_msg.data = const_cast<char*>(zerocopy_blob->data());
rpc_meta_msg.size = zerocopy_blob->size();
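// The no-op deallocator below captures the shared_ptr so the blob stays
// alive until the underlying sender is done with this message.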
rpc_meta_msg.deallocator = [zerocopy_blob](network::Message*) {};
CHECK_EQ(RPCContext::ThreadLocal()->sender->Send(
rpc_meta_msg, msg.server_id), ADD_SUCCESS);
// send real ndarray data
for (auto ptr : zc_write_strm.buffer_list()) {
network::Message ndarray_data_msg;
ndarray_data_msg.data = reinterpret_cast<char*>(ptr.data);
ndarray_data_msg.size = ptr.size;
NDArray tensor = ptr.tensor;
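// Likewise, capturing the NDArray keeps its buffer alive while it is sent.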
ndarray_data_msg.deallocator = [tensor](network::Message*) {};
CHECK_EQ(RPCContext::ThreadLocal()->sender->Send(
ndarray_data_msg, msg.server_id), ADD_SUCCESS);
}
return kRPCSuccess;
}
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
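// Mirror of SendRPCMessage: read the meta message, peel the trailing int32
// tensor count off its tail, then receive that many raw tensor messages from
// the same sender before deserializing the whole RPCMessage.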
// Timeout is not supported yet; only a blocking receive with timeout == 0.
CHECK_EQ(timeout, 0) << "RPC does not support timeout yet.";
network::Message rpc_meta_msg;
int send_id;
CHECK_EQ(RPCContext::ThreadLocal()->receiver->Recv(
&rpc_meta_msg, &send_id), REMOVE_SUCCESS);
// Copy the data for now, can be optimized later
std::string zerocopy_blob(
rpc_meta_msg.data,
rpc_meta_msg.size-sizeof(int32_t));
char* count_ptr = rpc_meta_msg.data+rpc_meta_msg.size-sizeof(int32_t);
int32_t ndarray_count = *(reinterpret_cast<int32_t*>(count_ptr));
rpc_meta_msg.deallocator(&rpc_meta_msg);
// Recv real ndarray data
std::vector<void*> buffer_list(ndarray_count);
for (int i = 0; i < ndarray_count; ++i) {
network::Message ndarray_data_msg;
CHECK_EQ(RPCContext::ThreadLocal()->receiver->RecvFrom(
&ndarray_data_msg, send_id), REMOVE_SUCCESS);
buffer_list[i] = ndarray_data_msg.data;
}
StringStreamWithBuffer zc_read_strm(&zerocopy_blob, buffer_list);
static_cast<dmlc::Stream *>(&zc_read_strm)->Read(msg);
return kRPCSuccess;
}
//////////////////////////// C APIs ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
if (type.compare("socket") == 0) {
RPCContext::ThreadLocal()->sender = std::make_shared<network::SocketSender>(msg_queue_size);
} else {
LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
}
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
if (type.compare("socket") == 0) {
RPCContext::ThreadLocal()->receiver = std::make_shared<network::SocketReceiver>(msg_queue_size);
} else {
LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
}
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
RPCContext::ThreadLocal()->sender->Finalize();
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
RPCContext::ThreadLocal()->receiver->Finalize();
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
int num_sender = args[2];
std::string addr;
if (RPCContext::ThreadLocal()->receiver->Type() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << RPCContext::ThreadLocal()->receiver->Type();
}
if (RPCContext::ThreadLocal()->receiver->Wait(addr.c_str(), num_sender) == false) {
LOG(FATAL) << "Wait sender socket failed.";
}
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCAddReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
int port = args[1];
int recv_id = args[2];
std::string addr;
if (RPCContext::ThreadLocal()->sender->Type() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << RPCContext::ThreadLocal()->sender->Type();
}
RPCContext::ThreadLocal()->sender->AddReceiver(addr.c_str(), recv_id);
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSenderConnect")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
if (RPCContext::ThreadLocal()->sender->Connect() == false) {
LOG(FATAL) << "Sender connection failed.";
}
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetRank")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t rank = args[0];
RPCContext::ThreadLocal()->rank = rank;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetRank")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->rank;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumServer")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t num_servers = args[0];
*rv = RPCContext::ThreadLocal()->num_servers = num_servers;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumServer")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->num_servers;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCIncrMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = (RPCContext::ThreadLocal()->msg_seq)++;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->msg_seq;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int64_t msg_seq = args[0];
RPCContext::ThreadLocal()->msg_seq = msg_seq;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->machine_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetMachineID")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t machine_id = args[0];
RPCContext::ThreadLocal()->machine_id = machine_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetNumMachines")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
*rv = RPCContext::ThreadLocal()->num_machines;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSetNumMachines")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int32_t num_machines = args[0];
RPCContext::ThreadLocal()->num_machines = num_machines;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCSendRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
RPCMessageRef msg = args[0];
*rv = SendRPCMessage(*(msg.sptr()));
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCRecvRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int32_t timeout = args[0];
RPCMessageRef msg = args[1];
*rv = RecvRPCMessage(msg.sptr().get(), timeout);
});
//////////////////////////// RPCMessage ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateEmptyRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<RPCMessage> rst(new RPCMessage);
*rv = rst;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateRPCMessage")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::shared_ptr<RPCMessage> rst(new RPCMessage);
rst->service_id = args[0];
rst->msg_seq = args[1];
rst->client_id = args[2];
rst->server_id = args[3];
const std::string data = args[4]; // directly assigning string value raises errors :(
rst->data = data;
rst->tensors = ListValueToVector<NDArray>(args[5]);
*rv = rst;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServiceId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->service_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetMsgSeq")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->msg_seq;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetClientId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->client_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetServerId")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
*rv = msg->server_id;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetData")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
DGLByteArray barr{msg->data.c_str(), msg->data.size()};
*rv = barr;
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCMessageGetTensors")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const RPCMessageRef msg = args[0];
List<Value> ret;
for (size_t i = 0; i < msg->tensors.size(); ++i) {
ret.push_back(Value(MakeValue(msg->tensors[i])));
}
*rv = ret;
});
//////////////////////////// ServerState ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
auto st = RPCContext::ThreadLocal()->server_state;
CHECK(st) << "Server state has not been initialized.";
*rv = st;
});
} // namespace rpc
} // namespace dgl
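To make the send path above concrete, here is a minimal client-side sketch that drives this layer directly; the address, queue size, and IDs are illustrative, and in practice the Python dgl.distributed.rpc module goes through the _CAPI_DGLRPC* packed functions registered above instead:
// Minimal client-side sketch (illustrative values only, not the actual API
// entry point used by DGL's Python layer).
void ToyClient() {
  auto* ctx = dgl::rpc::RPCContext::ThreadLocal();
  ctx->rank = 0;  // this client's ID
  ctx->sender = std::make_shared<dgl::network::SocketSender>(1024 * 1024);
  ctx->sender->AddReceiver("socket://127.0.0.1:30050", /*recv_id=*/0);
  CHECK(ctx->sender->Connect()) << "Sender connection failed.";
  dgl::rpc::RPCMessage msg;
  msg.service_id = 901231;      // must be registered on the server side
  msg.msg_seq = ctx->msg_seq++;
  msg.client_id = ctx->rank;
  msg.server_id = 0;
  msg.data = "hello world!";    // opaque payload bytes
  CHECK_EQ(dgl::rpc::SendRPCMessage(msg), dgl::rpc::kRPCSuccess);
}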
/*!
* Copyright (c) 2020 by Contributors
* \file rpc/rpc.h
* \brief Common headers for remote procedure call (RPC).
*/
#ifndef DGL_RPC_RPC_H_
#define DGL_RPC_RPC_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/thread_local.h>
#include <cstdint>
#include <memory>
#include <vector>
#include <string>
#include "./network/communicator.h"
#include "./network/socket_communicator.h"
#include "./network/msg_queue.h"
#include "./network/common.h"
#include "./server_state.h"
namespace dgl {
namespace rpc {
// Communicator handler type
typedef void* CommunicatorHandle;
/*! \brief Context information for RPC communication */
struct RPCContext {
/*!
* \brief Rank of this process.
*
* If the process is a client, this equals its client ID. Otherwise, the
* process is a server and this equals its server ID.
*/
int32_t rank = -1;
/*!
* \brief Current machine ID.
*/
int32_t machine_id = -1;
/*!
* \brief Total number of machines.
*/
int32_t num_machines = 0;
/*!
* \brief Message sequence number.
*/
int64_t msg_seq = 0;
/*!
* \brief Total number of servers.
*/
int32_t num_servers = 0;
/*!
* \brief Sender communicator.
*/
std::shared_ptr<network::Sender> sender;
/*!
* \brief Receiver communicator.
*/
std::shared_ptr<network::Receiver> receiver;
/*!
* \brief Server state data.
*
* If the process is a server, this stores the necessary server-side data.
* Otherwise, the process is a client and this stores a cache of the state of
* the server co-located with the client (if available). When the client
* invokes an RPC on the co-located server, it can thus perform the
* computation locally without an actual remote call.
*/
std::shared_ptr<ServerState> server_state;
/*! \brief Get the thread-local RPC context structure */
static RPCContext *ThreadLocal() {
return dmlc::ThreadLocalStore<RPCContext>::Get();
}
};
/*! \brief RPC message data structure
*
* This structure is exposed to Python and can be used as an argument or
* return value in the C API.
*/
struct RPCMessage : public runtime::Object {
/*! \brief Service ID */
int32_t service_id;
/*! \brief Sequence number of this message. */
int64_t msg_seq;
/*! \brief Client ID. */
int32_t client_id;
/*! \brief Server ID. */
int32_t server_id;
/*! \brief Payload buffer carried by this request.*/
std::string data;
/*! \brief Extra payloads in the form of tensors.*/
std::vector<runtime::NDArray> tensors;
bool Load(dmlc::Stream* stream) {
stream->Read(&service_id);
stream->Read(&msg_seq);
stream->Read(&client_id);
stream->Read(&server_id);
stream->Read(&data);
stream->Read(&tensors);
return true;
}
void Save(dmlc::Stream* stream) const {
stream->Write(service_id);
stream->Write(msg_seq);
stream->Write(client_id);
stream->Write(server_id);
stream->Write(data);
stream->Write(tensors);
}
static constexpr const char* _type_key = "rpc.RPCMessage";
DGL_DECLARE_OBJECT_TYPE_INFO(RPCMessage, runtime::Object);
};
DGL_DEFINE_OBJECT_REF(RPCMessageRef, RPCMessage);
/*! \brief RPC status flag */
enum RPCStatus {
kRPCSuccess = 0,
kRPCTimeOut,
};
/*!
* \brief Send out one RPC message.
*
* The operation is non-blocking -- it does not guarantee the payloads have
* reached the target or even have left the sender process. However,
* all the payloads (i.e., data and arrays) can be safely freed after this
* function returns.
*
* The data buffer in the request is copied into an internal buffer for
* transmission, while tensor payloads incur no memory copy (a.k.a. zero-copy).
* The underlying sending threads will hold references to the tensors until
* the contents have been transmitted.
*
* \param msg RPC message to send
* \return status flag
*/
RPCStatus SendRPCMessage(const RPCMessage& msg);
/*!
* \brief Receive one RPC message.
*
* The operation is blocking -- it returns once a message is received.
*
* \param msg The received message
* \param timeout The timeout value in milliseconds. If zero, wait indefinitely.
* \return status flag
*/
RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout = 0);
} // namespace rpc
} // namespace dgl
namespace dmlc {
DMLC_DECLARE_TRAITS(has_saveload, dgl::rpc::RPCMessage, true);
} // namespace dmlc
#endif // DGL_RPC_RPC_H_
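Since RPCMessage implements Load/Save, it can be round-tripped through the zero-copy stream; below is a small sketch modeled on the serializer unit tests further down (variable names illustrative):
// Serialize: tensor contents stay out of the string blob; only buffer
// references are collected on the side via buffer_list().
std::string blob;
dgl::StringStreamWithBuffer ws(&blob);
static_cast<dmlc::Stream*>(&ws)->Write(msg);  // msg: a dgl::rpc::RPCMessage
// Transport the blob and each (b.data, b.size) buffer separately, then:
std::vector<void*> bufs;
for (const auto& b : ws.buffer_list()) bufs.push_back(b.data);
dgl::rpc::RPCMessage out;
dgl::StringStreamWithBuffer rs(&blob, bufs);
static_cast<dmlc::Stream*>(&rs)->Read(&out);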
/*!
* Copyright (c) 2020 by Contributors
* \file rpc/server_state.h
* \brief Data structures for the state held by a DGL server.
*/
#ifndef DGL_RPC_SERVER_STATE_H_
#define DGL_RPC_SERVER_STATE_H_
#include <dgl/runtime/object.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/base_heterograph.h>
#include <unordered_map>
#include <string>
namespace dgl {
namespace rpc {
/*!
* \brief Data stored in one DGL server.
*
* In a distributed setting, DGL partitions all data associated with the graph
* (e.g., node and edge features, graph structure, etc.) to multiple partitions,
* each handled by one DGL server. Hence, the ServerState class includes all
* the data associated with a graph partition.
*
* Under some setups, users may want to deploy servers in a heterogeneous way
* -- servers are further divided into special groups for fetching/updating
* node/edge data and for sampling/querying on graph structure respectively.
* In this case, the ServerState can be configured to include only node/edge
* data or graph structure.
*
* Each machine can have multiple server and client processes, but only one
* server is the *master* server while all the others are backup servers. All
* clients and backup servers share the state of the master server via shared
* memory, which means the ServerState class must be serializable and large
* bulk data (e.g., node/edge features) must be stored in NDArray to leverage
* shared memory.
*/
struct ServerState : public runtime::Object {
/*! \brief Key value store for NDArray data */
std::unordered_map<std::string, runtime::NDArray> kv_store;
/*! \brief Graph structure of one partition */
HeteroGraphPtr graph;
/*! \brief Total number of nodes */
int64_t total_num_nodes = 0;
/*! \brief Total number of edges */
int64_t total_num_edges = 0;
static constexpr const char* _type_key = "server_state.ServerState";
DGL_DECLARE_OBJECT_TYPE_INFO(ServerState, runtime::Object);
};
DGL_DEFINE_OBJECT_REF(ServerStateRef, ServerState);
} // namespace rpc
} // namespace dgl
#endif // DGL_RPC_SERVER_STATE_H_
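As a rough illustration of how a master server might populate this state before serving; the partition graph and feature array here are assumed to exist already:
// Hypothetical server bootstrap; `part_graph` (a HeteroGraphPtr) and
// `feat` (a dgl::runtime::NDArray) are assumptions, totals are illustrative.
auto state = std::make_shared<dgl::rpc::ServerState>();
state->graph = part_graph;
state->total_num_nodes = 1000000;
state->total_num_edges = 5000000;
state->kv_store["node_feat"] = feat;
dgl::rpc::RPCContext::ThreadLocal()->server_state = state;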
@@ -8,6 +8,8 @@
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/shared_mem.h>
#include <dgl/zerocopy_serializer.h>
#include "runtime_base.h"
// deleter for arrays used by DLPack exporter
@@ -278,6 +280,76 @@ template std::vector<uint64_t> NDArray::ToVector<uint64_t>() const;
template std::vector<float> NDArray::ToVector<float>() const;
template std::vector<double> NDArray::ToVector<double>() const;
#ifndef _WIN32
std::shared_ptr<SharedMemory> NDArray::GetSharedMem() const {
return this->data_->mem;
}
#endif // _WIN32
void NDArray::Save(dmlc::Stream* strm) const {
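// Zero-copy path: if the caller passed a StringStreamWithBuffer, record a
// reference to this tensor in its side buffer list instead of copying the
// raw bytes into the stream.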
auto zc_strm = dynamic_cast<StringStreamWithBuffer*>(strm);
if (zc_strm) {
zc_strm->PushNDArray(*this);
return;
}
SaveDLTensor(strm, const_cast<DLTensor*>(operator->()));
}
bool NDArray::Load(dmlc::Stream* strm) {
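// Zero-copy path: pop the next tensor reference that Save recorded.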
auto zc_strm = dynamic_cast<StringStreamWithBuffer*>(strm);
if (zc_strm) {
*this = zc_strm->PopNDArray();
return true;
}
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved))
<< "Invalid DLTensor file format";
CHECK(header == kDGLNDArrayMagic)
<< "Invalid DLTensor file format";
DLContext ctx;
int ndim;
DLDataType dtype;
CHECK(strm->Read(&ctx))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&ndim))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&dtype))
<< "Invalid DLTensor file format";
CHECK_EQ(ctx.device_type, kDLCPU)
<< "Invalid DLTensor context: can only save as CPU tensor";
std::vector<int64_t> shape(ndim);
if (ndim != 0) {
CHECK(strm->ReadArray(&shape[0], ndim))
<< "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, ctx);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
num_elems *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == num_elems * elem_bytes)
<< "Invalid DLTensor file format";
if (data_byte_size != 0) {
// strm->Read will return the total number of elements successfully read.
// Therefore if data_byte_size is zero, the CHECK below would fail.
CHECK(strm->Read(ret->data, data_byte_size))
<< "Invalid DLTensor file format";
}
if (!DMLC_IO_NO_ENDIAN_SWAP) {
dmlc::ByteSwap(ret->data, elem_bytes, num_elems);
}
*this = ret;
return true;
}
} // namespace runtime
} // namespace dgl
import os
import time
import dgl
import backend as F
import unittest, pytest
from numpy.testing import assert_array_equal
INTEGER = 2
STR = 'hello world!'
HELLO_SERVICE_ID = 901231
TENSOR = F.zeros((10, 10), F.int64, F.cpu())
def test_rank():
dgl.distributed.set_rank(2)
assert dgl.distributed.get_rank() == 2
def test_msg_seq():
from dgl.distributed.rpc import get_msg_seq, incr_msg_seq
assert get_msg_seq() == 0
incr_msg_seq()
incr_msg_seq()
incr_msg_seq()
assert get_msg_seq() == 3
def foo(x, y):
assert x == 123
assert y == "abc"
class MyRequest(dgl.distributed.Request):
def __init__(self):
self.x = 123
self.y = "abc"
self.z = F.randn((3, 4))
self.foo = foo
def __getstate__(self):
return self.x, self.y, self.z, self.foo
def __setstate__(self, state):
self.x, self.y, self.z, self.foo = state
def process_request(self, server_state):
pass
class MyResponse(dgl.distributed.Response):
def __init__(self):
self.x = 432
def __getstate__(self):
return self.x
def __setstate__(self, state):
self.x = state
def simple_func(tensor):
return tensor
class HelloResponse(dgl.distributed.Response):
def __init__(self, hello_str, integer, tensor):
self.hello_str = hello_str
self.integer = integer
self.tensor = tensor
def __getstate__(self):
return self.hello_str, self.integer, self.tensor
def __setstate__(self, state):
self.hello_str, self.integer, self.tensor = state
class HelloRequest(dgl.distributed.Request):
def __init__(self, hello_str, integer, tensor, func):
self.hello_str = hello_str
self.integer = integer
self.tensor = tensor
self.func = func
def __getstate__(self):
return self.hello_str, self.integer, self.tensor, self.func
def __setstate__(self, state):
self.hello_str, self.integer, self.tensor, self.func = state
def process_request(self, server_state):
assert self.hello_str == STR
assert self.integer == INTEGER
new_tensor = self.func(self.tensor)
res = HelloResponse(self.hello_str, self.integer, new_tensor)
return res
def start_server():
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.start_server(server_id=0, ip_config='ip_config.txt', num_clients=1)
def start_client():
dgl.distributed.register_service(HELLO_SERVICE_ID, HelloRequest, HelloResponse)
dgl.distributed.connect_to_server(ip_config='ip_config.txt')
req = HelloRequest(STR, INTEGER, TENSOR, simple_func)
# test send and recv
dgl.distributed.send_request(0, req)
res = dgl.distributed.recv_response()
assert res.hello_str == STR
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# test remote_call
target_and_requests = []
for i in range(10):
target_and_requests.append((0, req))
res_list = dgl.distributed.remote_call(target_and_requests)
for res in res_list:
assert res.hello_str == STR
assert res.integer == INTEGER
assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))
# clean up
dgl.distributed.shutdown_servers()
dgl.distributed.finalize_client()
def test_serialize():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
SERVICE_ID = 12345
dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
req = MyRequest()
data, tensors = serialize_to_payload(req)
req1 = deserialize_from_payload(MyRequest, data, tensors)
req1.foo(req1.x, req1.y)
assert req.x == req1.x
assert req.y == req1.y
assert F.array_equal(req.z, req1.z)
res = MyResponse()
data, tensors = serialize_to_payload(res)
res1 = deserialize_from_payload(MyResponse, data, tensors)
assert res.x == res1.x
def test_rpc_msg():
from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage
SERVICE_ID = 32452
dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
req = MyRequest()
data, tensors = serialize_to_payload(req)
rpcmsg = RPCMessage(SERVICE_ID, 23, 0, 1, data, tensors)
assert rpcmsg.service_id == SERVICE_ID
assert rpcmsg.msg_seq == 23
assert rpcmsg.client_id == 0
assert rpcmsg.server_id == 1
assert len(rpcmsg.data) == len(data)
assert len(rpcmsg.tensors) == 1
assert F.array_equal(rpcmsg.tensors[0], req.z)
@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
def test_rpc():
ip_config = open("ip_config.txt", "w")
ip_config.write('127.0.0.1 30050 1\n')
ip_config.close()
pid = os.fork()
if pid == 0:
start_server()
else:
time.sleep(1)
start_client()
if __name__ == '__main__':
test_rank()
test_msg_seq()
test_serialize()
test_rpc_msg()
test_rpc()
@@ -8,7 +8,7 @@
#include <thread>
#include <vector>
#include "../src/graph/network/msg_queue.h"
#include "../src/rpc/network/msg_queue.h"
using std::string;
using dgl::network::Message;
@@ -9,8 +9,8 @@
#include <thread>
#include <vector>
#include "../src/graph/network/msg_queue.h"
#include "../src/graph/network/socket_communicator.h"
#include "../src/rpc/network/msg_queue.h"
#include "../src/rpc/network/socket_communicator.h"
using std::string;
@@ -7,7 +7,7 @@
#include <string>
#include <vector>
#include "../src/graph/network/common.h"
#include "../src/rpc/network/common.h"
using dgl::network::SplitStringUsing;
using dgl::network::StringPrintf;
#include <dgl/array.h>
#include <dgl/immutable_graph.h>
#include <dgl/zerocopy_serializer.h>
#include <dmlc/memory_io.h>
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <vector>
#include "../../src/graph/heterograph.h"
#include "../../src/graph/unit_graph.h"
#include "./common.h"
#ifndef _WIN32
using namespace dgl;
using namespace dgl::aten;
using namespace dmlc;
// Helper to convert an IdArray to a human-readable string
std::string IdArrayToStr(IdArray arr) {
arr = arr.CopyTo(DLContext{kDLCPU, 0});
int64_t len = arr->shape[0];
std::ostringstream oss;
oss << "(" << len << ")[";
if (arr->dtype.bits == 32) {
int32_t *data = static_cast<int32_t *>(arr->data);
for (int64_t i = 0; i < len; ++i) {
oss << data[i] << " ";
}
} else {
int64_t *data = static_cast<int64_t *>(arr->data);
for (int64_t i = 0; i < len; ++i) {
oss << data[i] << " ";
}
}
oss << "]";
return oss.str();
}
TEST(ZeroCopySerialize, NDArray) {
auto tensor1 = VecToIdArray<int64_t>({1, 2, 5, 3});
auto tensor2 = VecToIdArray<int64_t>({6, 6, 5, 7});
std::string nonzerocopy_blob;
dmlc::MemoryStringStream ifs(&nonzerocopy_blob);
static_cast<dmlc::Stream *>(&ifs)->Write(tensor1);
static_cast<dmlc::Stream *>(&ifs)->Write(tensor2);
std::string zerocopy_blob;
StringStreamWithBuffer zc_write_strm(&zerocopy_blob);
static_cast<dmlc::Stream *>(&zc_write_strm)->Write(tensor1);
static_cast<dmlc::Stream *>(&zc_write_strm)->Write(tensor2);
EXPECT_EQ(nonzerocopy_blob.size() - zerocopy_blob.size(), 126)
<< "Invalid save";
std::vector<void *> new_ptr_list;
// Use memcpy to mimic remote machine reconstruction
for (auto ptr : zc_write_strm.buffer_list()) {
auto new_ptr = malloc(ptr.size);
memcpy(new_ptr, ptr.data, ptr.size);
new_ptr_list.emplace_back(new_ptr);
}
NDArray loadtensor1, loadtensor2;
StringStreamWithBuffer zc_read_strm(&zerocopy_blob, new_ptr_list);
static_cast<dmlc::Stream *>(&zc_read_strm)->Read(&loadtensor1);
static_cast<dmlc::Stream *>(&zc_read_strm)->Read(&loadtensor2);
}
TEST(ZeroCopySerialize, SharedMem) {
auto tensor1 = VecToIdArray<int64_t>({1, 2, 5, 3});
DLDataType dtype = {kDLInt, 64, 1};
std::vector<int64_t> shape{4};
DLContext cpu_ctx = {kDLCPU, 0};
auto shared_tensor =
NDArray::EmptyShared("test", shape, dtype, cpu_ctx, true);
shared_tensor.CopyFrom(tensor1);
std::string nonzerocopy_blob;
dmlc::MemoryStringStream ifs(&nonzerocopy_blob);
static_cast<dmlc::Stream *>(&ifs)->Write(shared_tensor);
std::string zerocopy_blob;
StringStreamWithBuffer zc_write_strm(&zerocopy_blob, false);
static_cast<dmlc::Stream *>(&zc_write_strm)->Write(shared_tensor);
EXPECT_EQ(nonzerocopy_blob.size() - zerocopy_blob.size(), 51)
<< "Invalid save";
NDArray loadtensor1, loadtensor2;
StringStreamWithBuffer zc_read_strm(&zerocopy_blob, false);
static_cast<dmlc::Stream *>(&zc_read_strm)->Read(&loadtensor1);
}
TEST(ZeroCopySerialize, HeteroGraph) {
auto src = VecToIdArray<int64_t>({1, 2, 5, 3});
auto dst = VecToIdArray<int64_t>({1, 6, 2, 6});
auto mg1 = dgl::UnitGraph::CreateFromCOO(2, 9, 8, src, dst);
src = VecToIdArray<int64_t>({6, 2, 5, 1, 8});
dst = VecToIdArray<int64_t>({5, 2, 4, 8, 0});
auto mg2 = dgl::UnitGraph::CreateFromCOO(1, 9, 9, src, dst);
std::vector<HeteroGraphPtr> relgraphs;
relgraphs.push_back(mg1);
relgraphs.push_back(mg2);
src = VecToIdArray<int64_t>({0, 0});
dst = VecToIdArray<int64_t>({1, 0});
auto meta_gptr = ImmutableGraph::CreateFromCOO(3, src, dst);
auto hrptr = std::make_shared<HeteroGraph>(meta_gptr, relgraphs);
std::string nonzerocopy_blob;
dmlc::MemoryStringStream ifs(&nonzerocopy_blob);
static_cast<dmlc::Stream *>(&ifs)->Write(hrptr);
std::string zerocopy_blob;
StringStreamWithBuffer zc_write_strm(&zerocopy_blob, true);
static_cast<dmlc::Stream *>(&zc_write_strm)->Write(hrptr);
EXPECT_EQ(nonzerocopy_blob.size() - zerocopy_blob.size(), 745)
<< "Invalid save";
std::vector<void *> new_ptr_list;
// Use memcpy to mimic remote machine reconstruction
for (auto ptr : zc_write_strm.buffer_list()) {
auto new_ptr = malloc(ptr.size);
memcpy(new_ptr, ptr.data, ptr.size);
new_ptr_list.emplace_back(new_ptr);
}
auto gptr = dgl::Serializer::make_shared<HeteroGraph>();
StringStreamWithBuffer zc_read_strm(&zerocopy_blob, new_ptr_list);
static_cast<dmlc::Stream *>(&zc_read_strm)->Read(&gptr);
EXPECT_EQ(gptr->NumVertices(0), 9);
EXPECT_EQ(gptr->NumVertices(1), 8);
}
#endif // _WIN32
\ No newline at end of file