Unverified Commit 5d494c62 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[RPC] add C++ RPC infrastructure and distributed sampler (#465)

* add C++ rpc infrastructure and distributed sampler

* update

* update lint

* update lint

* update lint

* update

* update

* update

* update

* update

* update

* update

* update serialize and unittest

* update serialize

* lint

* update

* update

* update

* update

* update

* update

* update unittest

* put Finalize() to __del__

* update unittest

* update

* delete buffer in Finalize

* update unittest

* update unittest

* update unittest

* update unittest

* update

* update

* fix small bug

* windows socket impl

* update API

* fix bug in serialize

* fix bug in serialize

* set parent graph

* update

* update

* update

* update

* update

* update

* fix lint

* fix lint

* fix

* fix windows compilation error

* fix windows error

* change API to lower-case

* update test

* fix typo

* update

* add SamplerPool

* add SamplerPool

* update

* update test

* update

* update

* update

* update

* add example

* update

* update
parent 6066fee9
...@@ -26,6 +26,7 @@ if(MSVC) ...@@ -26,6 +26,7 @@ if(MSVC)
add_definitions(-DWIN32_LEAN_AND_MEAN) add_definitions(-DWIN32_LEAN_AND_MEAN)
add_definitions(-D_CRT_SECURE_NO_WARNINGS) add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS) add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-DNOMINMAX)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
...@@ -52,7 +53,7 @@ else(MSVC) ...@@ -52,7 +53,7 @@ else(MSVC)
endif(MSVC) endif(MSVC)
# Source file lists # Source file lists
file(GLOB CORE_SRCS src/graph/*.cc src/*.cc src/scheduler/*.cc) file(GLOB CORE_SRCS src/graph/*.cc src/graph/network/* src/*.cc src/scheduler/*.cc)
file(GLOB RUNTIME_SRCS src/runtime/*.cc) file(GLOB RUNTIME_SRCS src/runtime/*.cc)
......
"""DGL root package.""" """DGL root package."""
# Windows compatibility
# This initializes Winsock and performs cleanup at termination as required
import socket
from . import function from . import function
from . import nn from . import nn
from . import contrib from . import contrib
......
from .sampler import NeighborSampler, LayerSampler from .sampler import NeighborSampler, LayerSampler
from .randomwalk import * from .randomwalk import *
from .dis_sampler import SamplerSender, SamplerReceiver
from .dis_sampler import SamplerPool
\ No newline at end of file
# This file contains DGL distributed samplers APIs.
from ...network import _send_subgraph, _recv_subgraph
from ...network import _create_sampler_sender, _create_sampler_receiver
from ...network import _finalize_sampler_sender, _finalize_sampler_receiver
from multiprocessing import Pool
from abc import ABCMeta, abstractmethod
class SamplerPool(object):
    """SamplerPool is an abstract class, in which the worker() method
    should be implemented by users. SamplerPool will fork N (N = num_worker)
    child processes, and each process will perform the worker() method
    independently. Note that the fork() API uses shared memory for the N
    processes and the OS will perform copy-on-write only when a child
    writes to that memory.

    Users can use this class like this:

      class MySamplerPool(SamplerPool):

          def worker(self):
              # Do anything here #

      if __name__ == '__main__':
          pool = MySamplerPool()
          pool.start(5)  # Start 5 processes
    """
    # NOTE(review): `__metaclass__` is Python-2-only syntax and is ignored by
    # Python 3, so @abstractmethod is not actually enforced there — confirm
    # whether Python 2 support is still required before switching to
    # `class SamplerPool(metaclass=ABCMeta)`.
    __metaclass__ = ABCMeta

    def start(self, num_worker):
        """Start `num_worker` child processes, each running `self.worker()`.

        Parameters
        ----------
        num_worker : int
            number of workers (child processes)
        """
        # Size the pool to num_worker: the default Pool() uses cpu_count()
        # workers, which would serialize tasks when num_worker > cpu_count().
        p = Pool(num_worker)
        for i in range(num_worker):
            print("Start child process %d ..." % i)
            p.apply_async(self.worker)
        # Wait for all subprocesses to finish.
        p.close()
        p.join()

    @abstractmethod
    def worker(self):
        """User-defined task executed in each child process."""
        pass
class SamplerSender(object):
    """Sender of the DGL distributed sampler.

    Users use the SamplerSender class to send sampled subgraphs (NodeFlow)
    to a remote trainer. Note that SamplerSender will try to connect to the
    SamplerReceiver in a loop until the SamplerReceiver has started.

    Parameters
    ----------
    ip : str
        ip address of remote trainer machine
    port : int
        port of remote trainer machine
    """
    def __init__(self, ip, port):
        self._ip = ip
        self._port = port
        self._sender = _create_sampler_sender(ip, port)

    def __del__(self):
        """Finalize the sender.

        _finalize_sampler_sender sends a special message to tell the
        remote trainer machine that this sender has finished its job.
        """
        # Guard: __init__ may have raised before `_sender` was assigned,
        # and __del__ still runs on the partially-built object.
        if hasattr(self, '_sender'):
            _finalize_sampler_sender(self._sender)

    def send(self, nodeflow):
        """Send a sampled subgraph (NodeFlow) to the remote trainer.

        Parameters
        ----------
        nodeflow : NodeFlow
            sampled NodeFlow object
        """
        _send_subgraph(self._sender, nodeflow)
class SamplerReceiver(object):
    """Receiver of the DGL distributed sampler.

    Users use the SamplerReceiver class to receive sampled subgraphs
    (NodeFlow) from remote samplers. Note that SamplerReceiver can receive
    messages from multiple senders concurrently, given the num_sender
    parameter; it starts its job only after all senders have connected.

    Parameters
    ----------
    ip : str
        ip address of current trainer machine
    port : int
        port of current trainer machine
    num_sender : int
        total number of sampler nodes, use 1 by default
    """
    def __init__(self, ip, port, num_sender=1):
        self._ip = ip
        self._port = port
        self._num_sender = num_sender
        self._receiver = _create_sampler_receiver(ip, port, num_sender)

    def __del__(self):
        """Finalize the receiver.

        _finalize_sampler_receiver cleans up the back-end threads
        started by the SamplerReceiver.
        """
        # Guard: __init__ may have raised before `_receiver` was assigned.
        if hasattr(self, '_receiver'):
            _finalize_sampler_receiver(self._receiver)

    def recv(self, graph):
        """Receive a NodeFlow object from a remote sampler.

        Parameters
        ----------
        graph : DGLGraph
            The parent graph

        Returns
        -------
        NodeFlow
            received NodeFlow object
        """
        return _recv_subgraph(self._receiver, graph)
"""DGL Distributed Training Infrastructure."""
from __future__ import absolute_import
from ._ffi.function import _init_api
from .nodeflow import NodeFlow
from .utils import unwrap_to_ptr_list
from . import utils
_init_api("dgl.network")
############################# Distributed Sampler #############################
def _create_sampler_sender(ip_addr, port):
    """Create a sender communicator via C socket.

    Parameters
    ----------
    ip_addr : str
        ip address of remote trainer
    port : int
        port of remote trainer

    Returns
    -------
    ctypes.c_void_p
        C sender handle
    """
    sender_handle = _CAPI_DGLSenderCreate(ip_addr, port)
    return sender_handle
def _create_sampler_receiver(ip_addr, port, num_sender):
    """Create a receiver communicator via C socket.

    Parameters
    ----------
    ip_addr : str
        ip address of remote trainer
    port : int
        listen port of remote trainer
    num_sender : int
        total number of sampler nodes

    Returns
    -------
    ctypes.c_void_p
        C receiver handle
    """
    receiver_handle = _CAPI_DGLReceiverCreate(ip_addr, port, num_sender)
    return receiver_handle
def _send_subgraph(sender, nodeflow):
    """Send a sampled subgraph (NodeFlow) to a remote trainer.

    Parameters
    ----------
    sender : ctypes.c_void_p
        C sender handle
    nodeflow : NodeFlow
        NodeFlow object
    """
    handle = nodeflow._graph._handle
    node_map = nodeflow._node_mapping.todgltensor()
    edge_map = nodeflow._edge_mapping.todgltensor()
    # TODO(review): converting NDArray via toindex() may be avoidable if a
    # direct tensor conversion exists.
    layer_off = utils.toindex(nodeflow._layer_offsets).todgltensor()
    block_off = utils.toindex(nodeflow._block_offsets).todgltensor()
    _CAPI_SenderSendSubgraph(sender, handle, node_map, edge_map,
                             layer_off, block_off)
def _recv_subgraph(receiver, graph):
    """Receive a sampled subgraph (NodeFlow) from a remote sampler.

    Parameters
    ----------
    receiver : ctypes.c_void_p
        C receiver handle
    graph : DGLGraph
        The parent graph

    Returns
    -------
    NodeFlow
        Sampled NodeFlow object
    """
    result = _CAPI_ReceiverRecvSubgraph(receiver)
    # unwrap_to_ptr_list yields a list of NodeFlow handles; exactly one
    # subgraph is delivered per call.
    handles = unwrap_to_ptr_list(result)
    return NodeFlow(graph, handles[0])
def _finalize_sampler_sender(sender):
    """Finalize the sender communicator and release its resources.

    Parameters
    ----------
    sender : ctypes.c_void_p
        C sender handle
    """
    # Sender and receiver share the same C finalize entry point.
    _CAPI_DGLFinalizeCommunicator(sender)
def _finalize_sampler_receiver(receiver):
    """Finalize the receiver communicator and release its resources.

    Parameters
    ----------
    receiver : ctypes.c_void_p
        C receiver handle
    """
    # Sender and receiver share the same C finalize entry point.
    _CAPI_DGLFinalizeCommunicator(receiver)
...@@ -17,6 +17,8 @@ namespace dgl { ...@@ -17,6 +17,8 @@ namespace dgl {
// Graph handler type // Graph handler type
typedef void* GraphHandle; typedef void* GraphHandle;
// Communicator handler type
typedef void* CommunicatorHandle;
/*! /*!
* \brief Convert the given DLTensor to DLManagedTensor. * \brief Convert the given DLTensor to DLManagedTensor.
* *
......
/*!
* Copyright (c) 2018 by Contributors
* \file graph/network.cc
* \brief DGL networking related APIs
*/
#include "./network.h"
#include "./network/communicator.h"
#include "./network/socket_communicator.h"
#include "./network/serialize.h"
#include "../c_api_common.h"
using dgl::runtime::DGLArgs;
using dgl::runtime::DGLArgValue;
using dgl::runtime::DGLRetValue;
using dgl::runtime::PackedFunc;
using dgl::runtime::NDArray;
namespace dgl {
namespace network {
// Process-wide scratch buffers used to (de)serialize NodeFlow messages.
// Allocated in _CAPI_DGLSenderCreate / _CAPI_DGLReceiverCreate and freed in
// _CAPI_DGLFinalizeCommunicator. NOTE(review): being globals, all senders
// (and all receivers) in one process share one buffer — confirm the
// one-communicator-per-process assumption before creating several.
static char* sender_data_buffer = nullptr;
static char* recv_data_buffer = nullptr;
///////////////////////// Distributed Sampler /////////////////////////
// Create a sender communicator and its serialization buffer.
// args: [0] receiver ip (str), [1] receiver port (int).
// Returns an opaque Communicator handle to the Python side.
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    std::string ip = args[0];
    int port = args[1];
    network::Communicator* comm = new network::SocketCommunicator();
    // Initialize() retries the connection in a loop; LOG(FATAL) aborts the
    // process on failure.
    if (comm->Initialize(IS_SENDER, ip.c_str(), port) == false) {
      LOG(FATAL) << "Initialize network communicator (sender) error.";
    }
    // NOTE(review): creating a second sender overwrites (and leaks) the
    // previous global buffer — see note at the buffer definitions.
    try {
      sender_data_buffer = new char[kMaxBufferSize];
    } catch (const std::bad_alloc&) {
      LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
    }
    CommunicatorHandle chandle = static_cast<CommunicatorHandle>(comm);
    *rv = chandle;
  });
// Create a receiver communicator and its deserialization buffer.
// args: [0] local ip (str), [1] listen port (int), [2] number of senders.
// Returns an opaque Communicator handle to the Python side.
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    std::string ip = args[0];
    int port = args[1];
    int num_sender = args[2];
    network::Communicator* comm = new network::SocketCommunicator();
    // Blocks until all num_sender senders have connected; kQueueSize bounds
    // the internal message queue.
    if (comm->Initialize(IS_RECEIVER, ip.c_str(), port, num_sender, kQueueSize) == false) {
      LOG(FATAL) << "Initialize network communicator (receiver) error.";
    }
    try {
      recv_data_buffer = new char[kMaxBufferSize];
    } catch (const std::bad_alloc&) {
      LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
    }
    CommunicatorHandle chandle = static_cast<CommunicatorHandle>(comm);
    *rv = chandle;
  });
// Serialize a NodeFlow (graph CSR + mapping/offset tensors) into the global
// sender buffer and ship it over the network.
// args: [0] communicator handle, [1] graph handle,
//       [2] node_mapping, [3] edge_mapping, [4] layer_offsets, [5] flow_offsets.
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    CommunicatorHandle chandle = args[0];
    GraphHandle ghandle = args[1];
    const IdArray node_mapping = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[2]));
    const IdArray edge_mapping = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[3]));
    const IdArray layer_offsets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[4]));
    const IdArray flow_offsets = IdArray::FromDLPack(CreateTmpDLManagedTensor(args[5]));
    ImmutableGraph *ptr = static_cast<ImmutableGraph*>(ghandle);
    network::Communicator* comm = static_cast<network::Communicator*>(chandle);
    // Only the in-CSR is transferred; the receiver rebuilds the graph from it.
    auto csr = ptr->GetInCSR();
    // Serialize nodeflow into the shared sender buffer.
    int64_t data_size = network::SerializeSampledSubgraph(
        sender_data_buffer,
        csr,
        node_mapping,
        edge_mapping,
        layer_offsets,
        flow_offsets);
    CHECK_GT(data_size, 0);
    // Send the serialized message; Send() returns bytes sent or <= 0 on error.
    int64_t size = comm->Send(sender_data_buffer, data_size);
    if (size <= 0) {
      LOG(ERROR) << "Send message error (size: " << size << ")";
    }
  });
// Receive one serialized NodeFlow message, deserialize it, and return a
// single-element vector holding the new NodeFlow handle.
// args: [0] communicator handle.
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    CommunicatorHandle chandle = args[0];
    network::Communicator* comm = static_cast<network::Communicator*>(chandle);
    // Blocking receive into the shared receiver buffer.
    int64_t size = comm->Receive(recv_data_buffer, kMaxBufferSize);
    if (size <= 0) {
      LOG(ERROR) << "Receive error: (size: " << size << ")";
    }
    // NOTE(review): ownership of `nf` is presumably transferred to the
    // Python side via WrapVectorReturn — confirm it is eventually freed.
    NodeFlow* nf = new NodeFlow();
    ImmutableGraph::CSR::Ptr csr;
    // Deserialize the nodeflow components from recv_data_buffer.
    network::DeserializeSampledSubgraph(recv_data_buffer,
                                        &csr,
                                        &(nf->node_mapping),
                                        &(nf->edge_mapping),
                                        &(nf->layer_offsets),
                                        &(nf->flow_offsets));
    nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr, false));
    std::vector<NodeFlow*> subgs(1);
    subgs[0] = nf;
    *rv = WrapVectorReturn(subgs);
  });
// Finalize a communicator (sender or receiver) and free the shared buffers.
// args: [0] communicator handle.
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeCommunicator")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
    CommunicatorHandle chandle = args[0];
    network::Communicator* comm = static_cast<network::Communicator*>(chandle);
    comm->Finalize();
    // Reset the pointers after freeing: a process holding both a sender and
    // a receiver calls this twice, and without the reset the second call
    // would delete[] already-freed memory (delete[] nullptr is a safe no-op).
    delete [] sender_data_buffer;
    sender_data_buffer = nullptr;
    delete [] recv_data_buffer;
    recv_data_buffer = nullptr;
    // NOTE(review): `comm` itself is never deleted, so the Communicator
    // object leaks — confirm the handle is unused afterwards before adding
    // `delete comm;`.
  });
} // namespace network
} // namespace dgl
/*!
 * Copyright (c) 2018 by Contributors
 * \file graph/network.h
 * \brief DGL networking related APIs
 */
#ifndef DGL_GRAPH_NETWORK_H_
#define DGL_GRAPH_NETWORK_H_
#include <dmlc/logging.h>
namespace dgl {
namespace network {
#define IS_SENDER true
#define IS_RECEIVER false
// TODO(chao): make these numbers configurable
// Each single message cannot be larger than 300 MB.
// BUGFIX: was `300 * 1024 * 2014` — the trailing `2014` is a typo for
// `1024`, which silently shrank the limit to ~295 MB (2014/1024 of intent
// would be wrong the other way; either way it contradicted the comment).
const int64_t kMaxBufferSize = 300 * 1024 * 1024;
// Size of message queue is 1 GB
const int64_t kQueueSize = 1024 * 1024 * 1024;
// Maximal try count of connection
const int kMaxTryCount = 500;
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_H_
/*!
* Copyright (c) 2019 by Contributors
* \file communicator.h
* \brief Communicator for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_COMMUNICATOR_H_
#include <string>
namespace dgl {
namespace network {
/*!
 * \brief Communicator for DGL distributed training.
 *
 * Communicator is a set of interfaces for network communication, which can
 * be implemented by real network libraries such as grpc or mpi, as well as
 * raw sockets. There are two types of Communicator: a Sender
 * (is_sender = true) and a Receiver. A Sender sends binary data to a remote
 * Receiver node; a Receiver listens on a specified endpoint and receives the
 * binary data sent by Sender nodes. Note that a receiver node can receive
 * messages from multiple senders concurrently.
 */
class Communicator {
 public:
  virtual ~Communicator() {}
  /*!
   * \brief Initialize Communicator
   * \param is_sender true for sender and false for receiver
   * \param ip IP address (for a sender, the remote receiver's address,
   *        e.g. "168.123.2.43"; for a receiver, the local address to
   *        listen on, e.g. "0.0.0.0")
   * \param port TCP port of the endpoint, e.g. 50051
   * \param num_sender number of senders, only used for receiver.
   * \param queue_size the size of message queue, only for receiver.
   * \return true for success and false for error
   */
  virtual bool Initialize(bool is_sender,
                          const char* ip,
                          int port,
                          int num_sender = 1,
                          int64_t queue_size = 5 * 1024 * 1024) = 0;
  /*!
   * \brief Send message to receiver node
   * \param src data pointer
   * \param size data size
   * \return bytes sent
   *   > 0 : bytes sent
   *   - 1 : error
   */
  virtual int64_t Send(char* src, int64_t size) = 0;
  /*!
   * \brief Receive message from a sender node; this actually
   * reads data from the local message queue.
   * \param dest destination data pointer
   * \param max_size maximal data size
   * \return bytes received
   *   > 0 : bytes received
   *   - 1 : error
   */
  virtual int64_t Receive(char* dest, int64_t max_size) = 0;
  /*!
   * \brief Finalize the Communicator class
   */
  virtual void Finalize() = 0;
};
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_COMMUNICATOR_H_
/*!
* Copyright (c) 2019 by Contributors
* \file msg_queue.cc
* \brief Message queue for DGL distributed training.
*/
#include <dmlc/logging.h>
#include <cstring>
#include "msg_queue.h"
namespace dgl {
namespace network {
using std::string;
// Construct a ring-buffer message queue of `queue_size` bytes shared by
// `num_producers` producer threads.
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
  // A non-positive capacity is a programming error.
  CHECK_LT(0, queue_size);
  queue_size_ = queue_size;
  free_size_ = queue_size;
  write_pointer_ = 0;
  num_producers_ = num_producers;
  try {
    queue_ = new char[queue_size];
  } catch (const std::bad_alloc&) {
    LOG(FATAL) << "Not enough memory for message queue.";
  }
  memset(queue_, '\0', queue_size);
}
// Destroy the queue, releasing the ring buffer under the lock.
MessageQueue::~MessageQueue() {
  std::lock_guard<std::mutex> lock(mutex_);
  if (queue_ != nullptr) {
    delete [] queue_;
    queue_ = nullptr;
  }
}
// Append one message of `size` bytes to the ring buffer.
// Returns `size` on success, 0 if non-blocking and the queue is full,
// -1 on error (message too large, empty, or all producers finished).
int64_t MessageQueue::Add(const char* src, int64_t size, bool is_blocking) {
  // A message larger than the whole buffer can never fit.
  if (size > queue_size_) {
    LOG(ERROR) << "Message is larger than the queue.";
    return -1;
  }
  if (size <= 0) {
    LOG(ERROR) << "Message size (" << size << ") is negative or zero.";
    return -1;
  }
  std::unique_lock<std::mutex> lock(mutex_);
  // Reject writes once every producer has signaled completion.
  if (finished_producers_.size() >= num_producers_) {
    LOG(ERROR) << "Can't add to buffer when flag_no_more is set.";
    return -1;
  }
  // Non-blocking mode: drop the message instead of waiting for space.
  if (size > free_size_ && !is_blocking) {
    LOG(WARNING) << "Queue is full and message lost.";
    return 0;
  }
  // Block until enough free space is available.
  // NOTE(review): this wait has no shutdown predicate; a producer blocked
  // here while no consumer drains the queue waits forever — confirm intended.
  cond_not_full_.wait(lock, [&]() {
    return size <= free_size_;
  });
  // Record the message's (start, length) before copying the payload.
  message_positions_.push(std::make_pair(write_pointer_, size));
  free_size_ -= size;
  if (write_pointer_ + size <= queue_size_) {
    // Contiguous case: the message fits before the end of the buffer.
    memcpy(&queue_[write_pointer_], src, size);
    write_pointer_ += size;
    if (write_pointer_ == queue_size_) {
      write_pointer_ = 0;
    }
  } else {
    // Wrap-around case: split the copy across the tail and the head.
    int64_t size_partial = queue_size_ - write_pointer_;
    memcpy(&queue_[write_pointer_], src, size_partial);
    memcpy(queue_, &src[size_partial], size - size_partial);
    write_pointer_ = size - size_partial;
  }
  // Wake one consumer: the queue is no longer empty.
  cond_not_empty_.notify_one();
  return size;
}
// Convenience overload: enqueue the bytes of a std::string.
int64_t MessageQueue::Add(const string &src, bool is_blocking) {
  const char* bytes = src.data();
  return Add(bytes, src.size(), is_blocking);
}
// Pop one message into `dest` (capacity `max_size`).
// Returns the message size, 0 if the queue is empty (non-blocking, or all
// producers finished and the queue is drained), -1 if the message exceeds
// `max_size`.
int64_t MessageQueue::Remove(char *dest, int64_t max_size, bool is_blocking) {
  int64_t retval;
  std::unique_lock<std::mutex> lock(mutex_);
  if (message_positions_.empty()) {
    if (!is_blocking) {
      return 0;
    }
    if (finished_producers_.size() >= num_producers_) {
      return 0;
    }
  }
  // Wait until there is a message, or all producers have signaled completion
  // (exit_flag_ is only set once every producer has called Signal()).
  cond_not_empty_.wait(lock, [this] {
    return !message_positions_.empty() || exit_flag_.load();
  });
  // BUGFIX: previously this returned 0 whenever all producers had finished,
  // which (a) dropped messages still sitting in the queue and (b) did not
  // guard front() when the wait was released by exit_flag_ with an empty
  // queue. Drain remaining messages; return 0 only when truly empty.
  if (message_positions_.empty()) {
    return 0;
  }
  MessagePosition & pos = message_positions_.front();
  // Reject a message that does not fit in the caller's buffer.
  if (pos.second > max_size) {
    LOG(ERROR) << "Message size exceeds limit, information lost.";
    retval = -1;
  } else {
    if (pos.first + pos.second <= queue_size_) {
      // Contiguous message: a single copy suffices.
      memcpy(dest, &queue_[pos.first], pos.second);
    } else {
      // Wrapped message: copy the tail of the buffer, then the head.
      int64_t size_partial = queue_size_ - pos.first;
      memcpy(dest, &queue_[pos.first], size_partial);
      memcpy(&dest[size_partial], queue_, pos.second - size_partial);
    }
    retval = pos.second;
  }
  // Space is reclaimed even when the message was too large (it is dropped).
  free_size_ += pos.second;
  message_positions_.pop();
  cond_not_full_.notify_one();
  return retval;
}
// Pop one message into a std::string. Returns the message size, or 0 if the
// queue is empty (non-blocking, or all producers finished and drained).
int64_t MessageQueue::Remove(string *dest, bool is_blocking) {
  int64_t retval;
  std::unique_lock<std::mutex> lock(mutex_);
  if (message_positions_.empty()) {
    if (!is_blocking) {
      return 0;
    }
    if (finished_producers_.size() >= num_producers_) {
      return 0;
    }
  }
  cond_not_empty_.wait(lock, [this] {
    return !message_positions_.empty() || exit_flag_.load();
  });
  // BUGFIX: the wait can also be released by exit_flag_ (all producers
  // finished) with nothing queued; the original code then called front()
  // on an empty queue, which is undefined behavior. Mirror the pointer
  // overload and report "empty" instead.
  if (message_positions_.empty()) {
    return 0;
  }
  MessagePosition & pos = message_positions_.front();
  if (pos.first + pos.second <= queue_size_) {
    // Contiguous message: one assign.
    dest->assign(&queue_[pos.first], pos.second);
  } else {
    // Wrapped message: tail of the buffer, then the head.
    int64_t size_partial = queue_size_ - pos.first;
    dest->assign(&queue_[pos.first], size_partial);
    dest->append(queue_, pos.second - size_partial);
  }
  retval = pos.second;
  free_size_ += pos.second;
  message_positions_.pop();
  cond_not_full_.notify_one();
  return retval;
}
// Mark producer `producer_id` as finished; once every producer has signaled,
// wake all blocked consumers so they can observe completion.
void MessageQueue::Signal(int producer_id) {
  std::lock_guard<std::mutex> lock(mutex_);
  finished_producers_.insert(producer_id);
  const bool all_done = finished_producers_.size() >= num_producers_;
  if (all_done) {
    exit_flag_.store(true);
    cond_not_empty_.notify_all();
  }
}
// True when the queue holds no messages and every producer has signaled.
bool MessageQueue::EmptyAndNoMoreAdd() const {
  std::lock_guard<std::mutex> lock(mutex_);
  const bool drained = message_positions_.empty();
  const bool producers_done = finished_producers_.size() >= num_producers_;
  return drained && producers_done;
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file msg_queue.h
* \brief Message queue for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#define DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#include <queue>
#include <set>
#include <string>
#include <utility> // for pair
#include <mutex>
#include <condition_variable>
#include <atomic>
namespace dgl {
namespace network {
/*!
* \brief Message Queue for DGL distributed training.
*
* MessageQueue is a circle queue for using the ring-buffer in a
* producer/consumer model. It supports one or more producer
* threads and one or more consumer threads. Producers invokes Add()
* to push data elements into the queue, and consumers invokes
* Remove() to pop data elements. Add() and Remove() use two condition
* variables to synchronize producers and consumers. Each producer invokes
* Signal(producer_id) to claim that it is about to finish, where
* producer_id is an integer uniquely identify a producer thread. This
* signaling mechanism prevents consumers from waiting after all producers
* have finished their jobs.
*
*/
class MessageQueue {
 public:
  /*!
   * \brief MessageQueue constructor
   * \param queue_size size of message queue, in bytes
   * \param num_producers number of producers, use 1 by default
   */
  MessageQueue(int64_t queue_size /* in bytes */,
               int num_producers = 1);
  /*!
   * \brief MessageQueue destructor
   */
  ~MessageQueue();
  /*!
   * \brief Add data to the message queue
   * \param src The data pointer
   * \param size The size of data
   * \param is_blocking Block function if cannot add, else return
   * \return bytes added to the queue
   *   > 0 : size of message
   *   = 0 : not enough space for this message (when is_blocking = false)
   *   - 1 : error
   */
  int64_t Add(const char* src, int64_t size, bool is_blocking = true);
  /*!
   * \brief Add data to the message queue
   * \param src The data string
   * \param is_blocking Block function if cannot add, else return
   * \return bytes added to queue
   *   > 0 : size of message
   *   = 0 : not enough space for this message (when is_blocking = false)
   *   - 1 : error
   */
  int64_t Add(const std::string& src, bool is_blocking = true);
  /*!
   * \brief Remove message from the queue
   * \param dest The destination data pointer
   * \param max_size Maximal size of data
   * \param is_blocking Block function if cannot remove, else return
   * \return bytes removed from queue
   *   > 0 : size of message
   *   = 0 : queue is empty
   *   - 1 : error
   */
  int64_t Remove(char *dest, int64_t max_size, bool is_blocking = true);
  /*!
   * \brief Remove message from the queue
   * \param dest The destination data string
   * \param is_blocking Block function if cannot remove, else return
   * \return bytes removed from queue
   *   > 0 : size of message
   *   = 0 : queue is empty
   *   - 1 : error
   */
  int64_t Remove(std::string *dest, bool is_blocking = true);
  /*!
   * \brief Signal that producer producer_id will no longer produce anything
   * \param producer_id An integer uniquely identifying a producer thread
   */
  void Signal(int producer_id);
  /*!
   * \return true if queue is empty and all num_producers have signaled.
   */
  bool EmptyAndNoMoreAdd() const;
 protected:
  typedef std::pair<int64_t /* message_start_position in queue_ */,
                    int64_t /* message_length */> MessagePosition;
  /*!
   * \brief Pointer to the ring buffer backing the queue
   */
  char* queue_;
  /*!
   * \brief Size of the queue in bytes
   */
  int64_t queue_size_;
  /*!
   * \brief Free size of the queue, in bytes
   */
  int64_t free_size_;
  /*!
   * \brief Location in queue_ for where to write the next element.
   * Note that we do not need a read pointer since all messages are indexed
   * by message_positions_, and the first element in message_positions_
   * denotes where we should read.
   */
  int64_t write_pointer_;
  /*!
   * \brief Used to check that all producers will no longer produce anything
   */
  int num_producers_;
  /*!
   * \brief (start, length) of each message currently in the queue, FIFO
   */
  std::queue<MessagePosition> message_positions_;
  /*!
   * \brief Ids of producers that have signaled completion
   */
  std::set<int /* producer_id */> finished_producers_;
  /*!
   * \brief Signaled when space is freed; producers (Add) wait on this
   * when the queue is full. (The original comments on these two condition
   * variables were swapped.)
   */
  std::condition_variable cond_not_full_;
  /*!
   * \brief Signaled when a message is enqueued; consumers (Remove) wait on
   * this when the queue is empty.
   */
  std::condition_variable cond_not_empty_;
  /*!
   * \brief Set once all producers have signaled, to release waiting consumers
   */
  std::atomic<bool> exit_flag_{false};
  /*!
   * \brief Protects all of the above data and conditions
   */
  mutable std::mutex mutex_;
};
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_MSG_QUEUE_H_
/*!
* Copyright (c) 2019 by Contributors
* \file serialize.cc
* \brief Serialization for DGL distributed training.
*/
#include "serialize.h"
#include <dmlc/logging.h>
#include <dgl/immutable_graph.h>
#include <cstring>
#include "../network.h"
namespace dgl {
namespace network {
const int kNumTensor = 7;  // We serialize 7 components (tensors) per message
// Serialize a sampled subgraph (NodeFlow) into `data` as a sequence of
// kNumTensor length-prefixed tensors and return the total bytes written.
// Write order (DeserializeSampledSubgraph must read in the same order):
// node_mapping, layer_offsets, flow_offsets, edge_mapping,
// indices, edge_ids, indptr — each preceded by its byte size as int64_t.
int64_t SerializeSampledSubgraph(char* data,
                                 const ImmutableGraph::CSR::Ptr csr,
                                 const IdArray& node_mapping,
                                 const IdArray& edge_mapping,
                                 const IdArray& layer_offsets,
                                 const IdArray& flow_offsets) {
  int64_t total_size = 0;
  // Compute the byte size of each component up front so the total can be
  // validated against the buffer capacity before any write.
  // NOTE(review): the deserializer divides these sizes by sizeof(int64_t),
  // so this assumes sizeof(dgl_id_t) == 8 — confirm.
  int64_t node_mapping_size = node_mapping->shape[0] * sizeof(dgl_id_t);
  int64_t edge_mapping_size = edge_mapping->shape[0] * sizeof(dgl_id_t);
  int64_t layer_offsets_size = layer_offsets->shape[0] * sizeof(dgl_id_t);
  int64_t flow_offsets_size = flow_offsets->shape[0] * sizeof(dgl_id_t);
  int64_t indptr_size = csr->indptr.size() * sizeof(int64_t);
  int64_t indices_size = csr->indices.size() * sizeof(dgl_id_t);
  int64_t edge_ids_size = csr->edge_ids.size() * sizeof(dgl_id_t);
  total_size += node_mapping_size;
  total_size += edge_mapping_size;
  total_size += layer_offsets_size;
  total_size += flow_offsets_size;
  total_size += indptr_size;
  total_size += indices_size;
  total_size += edge_ids_size;
  // One int64_t size prefix per serialized tensor.
  total_size += kNumTensor * sizeof(int64_t);
  if (total_size > kMaxBufferSize) {
    LOG(FATAL) << "Message size: (" << total_size
               << ") is larger than buffer size: ("
               << kMaxBufferSize << ")";
  }
  // Write binary data to buffer
  char* data_ptr = data;
  dgl_id_t* node_map_data = static_cast<dgl_id_t*>(node_mapping->data);
  dgl_id_t* edge_map_data = static_cast<dgl_id_t*>(edge_mapping->data);
  dgl_id_t* layer_off_data = static_cast<dgl_id_t*>(layer_offsets->data);
  dgl_id_t* flow_off_data = static_cast<dgl_id_t*>(flow_offsets->data);
  int64_t* indptr = static_cast<int64_t*>(csr->indptr.data());
  dgl_id_t* indices = static_cast<dgl_id_t*>(csr->indices.data());
  dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr->edge_ids.data());
  // node_mapping
  *(reinterpret_cast<int64_t*>(data_ptr)) = node_mapping_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, node_map_data, node_mapping_size);
  data_ptr += node_mapping_size;
  // layer_offsets
  *(reinterpret_cast<int64_t*>(data_ptr)) = layer_offsets_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, layer_off_data, layer_offsets_size);
  data_ptr += layer_offsets_size;
  // flow_offsets
  *(reinterpret_cast<int64_t*>(data_ptr)) = flow_offsets_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, flow_off_data, flow_offsets_size);
  data_ptr += flow_offsets_size;
  // edge_mapping
  *(reinterpret_cast<int64_t*>(data_ptr)) = edge_mapping_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, edge_map_data, edge_mapping_size);
  data_ptr += edge_mapping_size;
  // indices (CSR)
  *(reinterpret_cast<int64_t*>(data_ptr)) = indices_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, indices, indices_size);
  data_ptr += indices_size;
  // edge_ids (CSR)
  *(reinterpret_cast<int64_t*>(data_ptr)) = edge_ids_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, edge_ids, edge_ids_size);
  data_ptr += edge_ids_size;
  // indptr (CSR)
  *(reinterpret_cast<int64_t*>(data_ptr)) = indptr_size;
  data_ptr += sizeof(int64_t);
  memcpy(data_ptr, indptr, indptr_size);
  data_ptr += indptr_size;
  return total_size;
}
// Deserialize a sampled subgraph from `data`, allocating fresh IdArrays and
// a CSR. Components are read in the exact order SerializeSampledSubgraph
// wrote them; each is preceded by its byte size as an int64_t.
void DeserializeSampledSubgraph(char* data,
                                ImmutableGraph::CSR::Ptr* csr,
                                IdArray* node_mapping,
                                IdArray* edge_mapping,
                                IdArray* layer_offsets,
                                IdArray* flow_offsets) {
  char* data_ptr = data;
  // node_mapping — its element count defines the number of vertices.
  int64_t tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  int64_t num_vertices = tensor_size / sizeof(int64_t);
  data_ptr += sizeof(int64_t);
  *node_mapping = IdArray::Empty({static_cast<int64_t>(num_vertices)},
                                 DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  dgl_id_t* node_map_data = static_cast<dgl_id_t*>((*node_mapping)->data);
  memcpy(node_map_data, data_ptr, tensor_size);
  data_ptr += tensor_size;
  // layer offsets (num_hops + 1 entries)
  tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  int64_t num_hops_add_one = tensor_size / sizeof(int64_t);
  data_ptr += sizeof(int64_t);
  *layer_offsets = IdArray::Empty({static_cast<int64_t>(num_hops_add_one)},
                                  DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  dgl_id_t* layer_off_data = static_cast<dgl_id_t*>((*layer_offsets)->data);
  memcpy(layer_off_data, data_ptr, tensor_size);
  data_ptr += tensor_size;
  // flow offsets
  tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  int64_t num_hops = tensor_size / sizeof(int64_t);
  data_ptr += sizeof(int64_t);
  *flow_offsets = IdArray::Empty({static_cast<int64_t>(num_hops)},
                                 DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  dgl_id_t* flow_off_data = static_cast<dgl_id_t*>((*flow_offsets)->data);
  memcpy(flow_off_data, data_ptr, tensor_size);
  data_ptr += tensor_size;
  // edge_mapping — its element count defines the number of edges.
  tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  int64_t num_edges = tensor_size / sizeof(int64_t);
  data_ptr += sizeof(int64_t);
  *edge_mapping = IdArray::Empty({static_cast<int64_t>(num_edges)},
                                 DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
  dgl_id_t* edge_mapping_data = static_cast<dgl_id_t*>((*edge_mapping)->data);
  memcpy(edge_mapping_data, data_ptr, tensor_size);
  data_ptr += tensor_size;
  // Construct the sub-CSR graph sized from the counts recovered above.
  *csr = std::make_shared<ImmutableGraph::CSR>(num_vertices, num_edges);
  (*csr)->indices.resize(num_edges);
  (*csr)->edge_ids.resize(num_edges);
  // indices (CSR)
  tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  data_ptr += sizeof(int64_t);
  dgl_id_t* col_list_out = (*csr)->indices.data();
  memcpy(col_list_out, data_ptr, tensor_size);
  data_ptr += tensor_size;
  // edge_ids (CSR)
  tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  data_ptr += sizeof(int64_t);
  dgl_id_t* edge_ids = (*csr)->edge_ids.data();
  memcpy(edge_ids, data_ptr, tensor_size);
  data_ptr += tensor_size;
  // indptr (CSR)
  tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
  data_ptr += sizeof(int64_t);
  int64_t* indptr_out = (*csr)->indptr.data();
  memcpy(indptr_out, data_ptr, tensor_size);
  data_ptr += tensor_size;
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file serialize.h
* \brief Serialization for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_SERIALIZE_H_
#define DGL_GRAPH_NETWORK_SERIALIZE_H_
#include <dgl/sampler.h>
#include <dgl/immutable_graph.h>
namespace dgl {
namespace network {
/*!
 * \brief Serialize a sampled subgraph to binary data.
 * The buffer behind `data` must be able to hold the full message
 * (bounded by kMaxBufferSize; the implementation aborts if exceeded).
 * \param data pointer to the destination data buffer
 * \param csr subgraph csr
 * \param node_mapping node mapping in NodeFlowIndex
 * \param edge_mapping edge mapping in NodeFlowIndex
 * \param layer_offsets layer offsets in NodeFlowIndex
 * \param flow_offsets flow offsets in NodeFlowIndex
 * \return the total size of the serialized binary data
 */
int64_t SerializeSampledSubgraph(char* data,
                                 const ImmutableGraph::CSR::Ptr csr,
                                 const IdArray& node_mapping,
                                 const IdArray& edge_mapping,
                                 const IdArray& layer_offsets,
                                 const IdArray& flow_offsets);
/*!
 * \brief Deserialize a sampled subgraph from binary data.
 * `data` must hold a message produced by SerializeSampledSubgraph;
 * all output arrays and the csr are freshly allocated.
 * \param data pointer to the source data buffer
 * \param csr output subgraph csr
 * \param node_mapping node mapping in NodeFlowIndex
 * \param edge_mapping edge mapping in NodeFlowIndex
 * \param layer_offsets layer offsets in NodeFlowIndex
 * \param flow_offsets flow offsets in NodeFlowIndex
 */
void DeserializeSampledSubgraph(char* data,
                                ImmutableGraph::CSR::Ptr* csr,
                                IdArray* node_mapping,
                                IdArray* edge_mapping,
                                IdArray* layer_offsets,
                                IdArray* flow_offsets);
// TODO(chao): we can add compression and decompression method here
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_SERIALIZE_H_
/*!
* Copyright (c) 2019 by Contributors
* \file communicator.cc
* \brief SocketCommunicator for DGL distributed training.
*/
#include <dmlc/logging.h>
#include "socket_communicator.h"
#include "../../c_api_common.h"
#include "../network.h"
#ifdef _WIN32
#include <windows.h>
#else // !_WIN32
#include <unistd.h>
#endif // _WIN32
namespace dgl {
namespace network {
const int kTimeOut = 10; // 10 minutes for socket timeout
const int kMaxConnection = 1024; // 1024 maximal socket connection
bool SocketCommunicator::Initialize(bool is_sender,
                                    const char* ip,
                                    int port,
                                    int num_sender,
                                    int64_t queue_size) {
  // Record the role, then dispatch to the matching setup routine.
  is_sender_ = is_sender;
  if (is_sender_) {
    return InitSender(ip, port);
  }
  return InitReceiver(ip, port, num_sender, queue_size);
}
bool SocketCommunicator::InitSender(const char* ip, int port) {
  // The sender owns exactly one client socket, stored at socket_[0].
  socket_.resize(1);
  socket_[0] = new TCPSocket();
  TCPSocket* client = socket_[0];
  // Retry the connection up to kMaxTryCount times, sleeping one second
  // between attempts.  (The original kept a `bo` flag that was never set
  // to true; the loop condition reduced to the try counter alone.)
  for (int try_count = 0; try_count < kMaxTryCount; ++try_count) {
    if (client->Connect(ip, port)) {
      LOG(INFO) << "Connected to " << ip << ":" << port;
      return true;
    }
    LOG(ERROR) << "Cannot connect to " << ip << ":" << port
               << ", try again ...";
#ifdef _WIN32
    // BUGFIX: Win32 Sleep() takes milliseconds; Sleep(1) slept only 1 ms
    // while the POSIX branch slept a full second.
    Sleep(1000);
#else   // !_WIN32
    sleep(1);
#endif  // _WIN32
  }
  // All attempts failed: free the socket so Finalize() does not try to
  // send the shutdown signal over a never-connected socket.
  delete socket_[0];
  socket_[0] = nullptr;
  return false;
}
bool SocketCommunicator::InitReceiver(const char* ip,
                                      int port,
                                      int num_sender,
                                      int64_t queue_size) {
  // Receiver setup: bind + listen on (ip, port), then block until all
  // num_sender peers have connected, spawning one handler thread each.
  CHECK_GE(num_sender, 1);
  CHECK_GT(queue_size, 0);
  // Init message queue shared by all handler threads.
  num_sender_ = num_sender;
  queue_size_ = queue_size;
  queue_ = new MessageQueue(queue_size_, num_sender_);
  // Init socket list: socket_[0] is the server (listening) socket,
  // socket_[1..num_sender] are the accepted sender connections.
  socket_.resize(num_sender+1);
  thread_.resize(num_sender);
  socket_[0] = new TCPSocket();
  server->SetTimeout(kTimeOut * 60 * 1000);  // millsec
  TCPSocket* server = socket_[0];
  // Bind socket
  if (server->Bind(ip, port) == false) {
    LOG(ERROR) << "Cannot bind to " << ip << ":" << port;
    return false;
  }
  LOG(INFO) << "Bind to " << ip << ":" << port;
  // Listen
  if (server->Listen(kMaxConnection) == false) {
    LOG(ERROR) << "Cannot listen on " << ip << ":" << port;
    return false;
  }
  LOG(INFO) << "Listen on " << ip << ":" << port << ", wait sender connect ...";
  // Accept all sender sockets; blocks until every sender has connected
  // (subject to the receive timeout set above).
  std::string accept_ip;
  int accept_port;
  for (int i = 1; i <= num_sender_; ++i) {
    socket_[i] = new TCPSocket();
    if (server->Accept(socket_[i], &accept_ip, &accept_port) == false) {
      LOG(ERROR) << "Error on accept socket.";
      return false;
    }
    // Each accepted connection gets a dedicated thread that drains the
    // socket into the shared message queue.
    thread_[i-1] = new std::thread(MsgHandler, socket_[i], queue_);
    LOG(INFO) << "Accept new sender: " << accept_ip << ":" << accept_port;
  }
  return true;
}
void SocketCommunicator::MsgHandler(TCPSocket* socket, MessageQueue* queue) {
  // Worker loop for one accepted sender connection.  Each message is
  // framed as an int64_t byte count followed by the payload; a size <= 0
  // (the sender's exit signal) or a socket error terminates the handler.
  char* buffer = new char[kMaxBufferSize];
  bool exit = false;
  while (!exit) {
    // First recv the 8-byte length prefix.
    int64_t received_bytes = 0;
    int64_t data_size = 0;
    while (received_bytes < static_cast<int64_t>(sizeof(int64_t))) {
      int64_t max_len = sizeof(int64_t) - received_bytes;
      int64_t tmp = socket->Receive(
        reinterpret_cast<char*>(&data_size)+received_bytes,
        max_len);
      if (tmp <= 0) {
        // BUGFIX: Receive() returns -1 on error / 0 on peer close; the
        // original added it to received_bytes and spun forever.
        exit = true;
        break;
      }
      received_bytes += tmp;
    }
    if (exit || data_size <= 0) {
      LOG(INFO) << "Socket finish job";
      break;
    }
    // BUGFIX: a malformed/oversized frame previously overflowed `buffer`.
    CHECK_LE(data_size, kMaxBufferSize);
    // Then recv the payload.
    received_bytes = 0;
    while (received_bytes < data_size) {
      int64_t max_len = data_size - received_bytes;
      int64_t tmp = socket->Receive(buffer+received_bytes, max_len);
      if (tmp <= 0) {  // error or connection closed mid-message
        exit = true;
        break;
      }
      received_bytes += tmp;
    }
    if (exit) break;
    queue->Add(buffer, data_size);
  }
  delete [] buffer;
}
void SocketCommunicator::Finalize() {
  // Tear down whichever role this communicator was initialized as.
  if (is_sender_) {
    FinalizeSender();
    return;
  }
  FinalizeReceiver();
}
void SocketCommunicator::FinalizeSender() {
// We send a size = -1 signal to notify
// receiver to finish its job
if (socket_[0] != nullptr) {
int64_t size = -1;
int64_t sent_bytes = 0;
while (sent_bytes < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket_[0]->Send(
reinterpret_cast<char*>(&size)+sent_bytes,
max_len);
sent_bytes += tmp;
}
socket_[0]->Close();
LOG(INFO) << "Close sender socket.";
delete socket_[0];
socket_[0] = nullptr;
}
}
void SocketCommunicator::FinalizeReceiver() {
  // Close and free the server socket (index 0) and every accepted sender
  // socket (indices 1..num_sender_).
  // NOTE(review): the MsgHandler threads in thread_ are never joined or
  // deleted here, and queue_ is never freed — looks like a leak; confirm
  // ownership/lifetime before changing, since joining could block if a
  // sender never sent its exit signal.
  for (int i = 0; i <= num_sender_; ++i) {
    if (socket_[i] != nullptr) {
      socket_[i]->Close();
      delete socket_[i];
      socket_[i] = nullptr;
    }
  }
}
int64_t SocketCommunicator::Send(char* src, int64_t size) {
if (!is_sender_) {
LOG(ERROR) << "Receiver cannot invoke send() API.";
return -1;
}
TCPSocket* client = socket_[0];
// First sent the size of data
int64_t sent_bytes = 0;
while (sent_bytes < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = client->Send(
reinterpret_cast<char*>(&size)+sent_bytes,
max_len);
sent_bytes += tmp;
}
// Then send the data
sent_bytes = 0;
while (sent_bytes < size) {
int64_t max_len = size - sent_bytes;
int64_t tmp = client->Send(src+sent_bytes, max_len);
sent_bytes += tmp;
}
return size + sizeof(int64_t);
}
int64_t SocketCommunicator::Receive(char* dest, int64_t max_size) {
  // Receiver-side API: messages are staged into queue_ by the MsgHandler
  // threads, so this only pops from the local message queue.
  if (is_sender_) {
    LOG(ERROR) << "Sender cannot invoke Receive() API.";
    return -1;
  }
  // Get message from the message queue (presumably blocks until one is
  // available — see MessageQueue::Remove; confirm).
  return queue_->Remove(dest, max_size);
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file communicator.h
* \brief SocketCommunicator for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
#include <thread>
#include <vector>
#include <string>
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
namespace dgl {
namespace network {
using dgl::network::MessageQueue;
using dgl::network::TCPSocket;
/*!
* \brief Implementation of Communicator class with TCP socket.
*/
class SocketCommunicator : public Communicator {
 public:
  /*!
   * \brief Initialize Communicator
   * \param is_sender true for sender and false for receiver
   * \param ip for a sender, the receiver's IP address to connect to;
   *        for a receiver, the local address to listen on (e.g. "0.0.0.0")
   * \param port the TCP port to connect to (sender) or listen on (receiver)
   * \param num_sender number of senders, only used for receiver.
   * \param queue_size the size of message queue, only for receiver.
   * \return true for success and false for error
   */
  bool Initialize(bool is_sender,
                  const char* ip,
                  int port,
                  int num_sender = 1,
                  int64_t queue_size = 5 * 1024 * 1024);
  /*!
   * \brief Send message to receiver node
   * \param src data pointer
   * \param size data size
   * \return bytes sent
   *   > 0 : bytes sent
   *   - 1 : error
   */
  int64_t Send(char* src, int64_t size);
  /*!
   * \brief Receive message from sender node; this actually
   * reads data from the local message queue.
   * \param dest destination data pointer
   * \param max_size maximal data size
   * \return bytes received
   *   > 0 : bytes received
   *   - 1 : error
   */
  int64_t Receive(char* dest, int64_t max_size);
  /*!
   * \brief Finalize the SocketCommunicator class
   */
  void Finalize();

 private:
  /*!
   * \brief Is this node a sender or a receiver?
   */
  bool is_sender_;
  /*!
   * \brief number of senders (receiver side only)
   */
  int num_sender_;
  /*!
   * \brief maximal size of message queue
   */
  int64_t queue_size_;
  /*!
   * \brief socket list; for a sender only [0] is used (the client socket),
   * for a receiver [0] is the server socket and [1..num_sender_] are the
   * accepted sender connections
   */
  std::vector<TCPSocket*> socket_;
  /*!
   * \brief one message-handler thread per accepted sender connection
   */
  std::vector<std::thread*> thread_;
  /*!
   * \brief Message queue for communicator
   */
  MessageQueue* queue_;
  /*!
   * \brief Initialize sender node
   * \param ip receiver ip address
   * \param port receiver port
   * \return true for success and false for error
   */
  bool InitSender(const char* ip, int port);
  /*!
   * \brief Initialize receiver node
   * \param ip receiver ip address
   * \param port receiver port
   * \param num_sender number of senders
   * \param queue_size size of message queue
   * \return true for success and false for error
   */
  bool InitReceiver(const char* ip,
                    int port,
                    int num_sender,
                    int64_t queue_size);
  /*!
   * \brief Finalize sender node
   */
  void FinalizeSender();
  /*!
   * \brief Finalize receiver node
   */
  void FinalizeReceiver();
  /*!
   * \brief Process received messages in independent threads
   * \param socket newly accepted socket
   * \param queue message queue
   */
  static void MsgHandler(TCPSocket* socket, MessageQueue* queue);
};
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_SOCKET_COMMUNICATOR_H_
/*!
* Copyright (c) 2019 by Contributors
* \file tcp_socket.cc
* \brief TCP socket for DGL distributed training.
*/
#include "tcp_socket.h"
#include <dmlc/logging.h>
#ifndef _WIN32
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <unistd.h>
#endif  // !_WIN32
namespace dgl {
namespace network {
typedef struct sockaddr_in SAI;
typedef struct sockaddr SA;
TCPSocket::TCPSocket() {
  // Create a TCP (stream) socket; abort via LOG(FATAL) if creation fails.
  // NOTE(review): on Windows this requires WSAStartup() to have been called
  // elsewhere, and Winsock's SOCKET type is unsigned while socket_ is a
  // signed int — confirm both are handled by callers/build.
  socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  if (socket_ < 0) {
    LOG(FATAL) << "Can't create new socket.";
  }
}
TCPSocket::~TCPSocket() {
  // Release the descriptor; Close() is a no-op if already closed.
  Close();
}
bool TCPSocket::Connect(const char * ip, int port) {
  // Connect this socket to the server at ip:port.
  // BUGFIX: zero-initialize the address struct — the original left
  // sin_zero and any padding uninitialized before passing it to connect().
  SAI sa_server = {};
  sa_server.sin_family = AF_INET;
  sa_server.sin_port = htons(port);
  // inet_pton() returns 1 on a successful parse; connect() returns 0 on success.
  if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
      0 <= connect(socket_, reinterpret_cast<SA*>(&sa_server),
                   sizeof(sa_server))) {
    return true;
  }
  LOG(ERROR) << "Failed connect to " << ip << ":" << port;
  return false;
}
bool TCPSocket::Bind(const char * ip, int port) {
  // Bind this socket to the local address ip:port.
  // BUGFIX: zero-initialize the address struct — the original left
  // sin_zero and any padding uninitialized before passing it to bind().
  SAI sa_server = {};
  sa_server.sin_family = AF_INET;
  sa_server.sin_port = htons(port);
  if (0 < inet_pton(AF_INET, ip, &sa_server.sin_addr) &&
      0 <= bind(socket_, reinterpret_cast<SA*>(&sa_server),
                sizeof(sa_server))) {
    return true;
  }
  LOG(ERROR) << "Failed bind on " << ip << ":" << port;
  return false;
}
bool TCPSocket::Listen(int max_connection) {
  // Mark the socket as passive with the given backlog; listen() returns
  // a non-negative value on success.
  const int rc = listen(socket_, max_connection);
  if (rc < 0) {
    LOG(ERROR) << "Failed listen on socket fd: " << socket_;
    return false;
  }
  return true;
}
bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
  // Block until a client connects; on success fill `socket`, `ip` and
  // `port` with the accepted connection's descriptor and peer address.
  int sock_client;
  SAI sa_client;
  socklen_t len = sizeof(sa_client);
  sock_client = accept(socket_, reinterpret_cast<SA*>(&sa_client), &len);
  if (sock_client < 0) {
    // BUGFIX: the original logged *ip and *port here, but both are OUTPUT
    // parameters that have not been assigned yet — reading *port was
    // undefined behavior. Log the listening descriptor instead.
    LOG(ERROR) << "Failed accept connection on socket fd: " << socket_;
    return false;
  }
  char tmp[INET_ADDRSTRLEN];
  const char * ip_client = inet_ntop(AF_INET,
                                     &sa_client.sin_addr,
                                     tmp,
                                     sizeof(tmp));
  CHECK(ip_client != nullptr);
  ip->assign(ip_client);
  *port = ntohs(sa_client.sin_port);
  socket->socket_ = sock_client;
  return true;
}
#ifdef _WIN32
bool TCPSocket::SetBlocking(bool flag) {
  // NOTE(review): flag == true enables NON-blocking mode (FIONBIO with a
  // nonzero argp) in this implementation, which contradicts the method
  // name and the header doc ("flag for blocking") — confirm the intended
  // semantics before relying on it.
  int result;
  u_long argp = flag ? 1 : 0;
  // XXX Non-blocking Windows Sockets apparently has tons of issues:
  // http://www.sockets.com/winsock.htm#Overview_BlockingNonBlocking
  // Since SetBlocking() is not used at all, I'm leaving a default
  // implementation here. But be warned that this is not fully tested.
  if ((result = ioctlsocket(socket_, FIONBIO, &argp)) != NO_ERROR) {
    LOG(ERROR) << "Failed to set socket status.";
    return false;
  }
  return true;
}
#else   // !_WIN32
bool TCPSocket::SetBlocking(bool flag) {
  // NOTE(review): as in the Windows branch, flag == true SETS O_NONBLOCK
  // (i.e. makes the socket non-blocking) — confirm intended semantics.
  int opts;
  if ((opts = fcntl(socket_, F_GETFL)) < 0) {
    LOG(ERROR) << "Failed to get socket status.";
    return false;
  }
  if (flag) {
    opts |= O_NONBLOCK;
  } else {
    opts &= ~O_NONBLOCK;
  }
  if (fcntl(socket_, F_SETFL, opts) < 0) {
    LOG(ERROR) << "Failed to set socket status.";
    return false;
  }
  return true;
}
#endif  // _WIN32
void TCPSocket::SetTimeout(int timeout) {
setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char*>(&timeout), sizeof(timeout));
}
bool TCPSocket::ShutDown(int ways) {
  // shutdown() returns 0 on success; report that as a boolean.
  const int rc = shutdown(socket_, ways);
  return rc == 0;
}
void TCPSocket::Close() {
  // Close the descriptor if it is still open; idempotent because socket_
  // is reset to -1 afterwards (the destructor also calls Close()).
  if (socket_ >= 0) {
#ifdef _WIN32
    // NOTE(review): CHECK_EQ aborts the process if close fails — confirm
    // that is intended for a teardown path.
    CHECK_EQ(0, closesocket(socket_));
#else   // !_WIN32
    CHECK_EQ(0, close(socket_));
#endif  // _WIN32
    socket_ = -1;
  }
}
int64_t TCPSocket::Send(const char * data, int64_t len_data) {
  // Thin wrapper over send(2): may transfer fewer than len_data bytes and
  // returns -1 on error; callers are responsible for retry loops.
  return send(socket_, data, len_data, 0);
}
int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
  // Thin wrapper over recv(2): may transfer fewer than size_buffer bytes,
  // returns 0 when the peer closed the connection and -1 on error.
  return recv(socket_, buffer, size_buffer, 0);
}
int TCPSocket::Socket() const {
  // Accessor for the raw socket descriptor.
  return socket_;
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file tcp_socket.h
* \brief TCP socket for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_TCP_SOCKET_H_
#define DGL_GRAPH_NETWORK_TCP_SOCKET_H_
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma comment(lib, "Ws2_32.lib")
#else // !_WIN32
#include <sys/socket.h>
#endif // _WIN32
#include <string>
namespace dgl {
namespace network {
/*!
* \brief TCPSocket is a simple wrapper around a socket.
* It supports only TCP connections.
*/
class TCPSocket {
 public:
  /*!
   * \brief TCPSocket constructor
   */
  TCPSocket();
  /*!
   * \brief TCPSocket destructor
   */
  ~TCPSocket();
  /*!
   * \brief Connect to a given server address
   * \param ip ip address
   * \param port end port
   * \return true for success and false for failure
   */
  bool Connect(const char * ip, int port);
  /*!
   * \brief Bind on the given IP and PORT
   * \param ip ip address
   * \param port end port
   * \return true for success and false for failure
   */
  bool Bind(const char * ip, int port);
  /*!
   * \brief Listen for remote connections
   * \param max_connection maximal number of pending connections (backlog)
   * \return true for success and false for failure
   */
  bool Listen(int max_connection);
  /*!
   * \brief Wait for a new connection
   * \param socket new SOCKET will be stored to socket
   * \param ip_client new IP will be stored to ip_client
   * \param port_client new PORT will be stored to port_client
   * \return true for success and false for failure
   */
  bool Accept(TCPSocket * socket,
              std::string * ip_client,
              int * port_client);
  /*!
   * \brief SetBlocking() is needed referring to this example of epoll:
   * http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html
   * \param flag flag for blocking
   * NOTE(review): both implementations enable NON-blocking mode when
   * flag is true — confirm the intended meaning of `flag`.
   * \return true for success and false for failure
   */
  bool SetBlocking(bool flag);
  /*!
   * \brief Set receive timeout for socket
   * \param timeout timeout in milliseconds
   */
  void SetTimeout(int timeout);
  /*!
   * \brief Shut down one or both halves of the connection.
   * \param ways ways for shutdown
   * If ways is SHUT_RD, further receives are disallowed.
   * If ways is SHUT_WR, further sends are disallowed.
   * If ways is SHUT_RDWR, further sends and receives are disallowed.
   * \return true for success and false for failure
   */
  bool ShutDown(int ways);
  /*!
   * \brief close socket.
   */
  void Close();
  /*!
   * \brief Send data.
   * \param data data for sending
   * \param len_data length of data
   * \return return number of bytes sent if OK, -1 on error
   */
  int64_t Send(const char * data, int64_t len_data);
  /*!
   * \brief Receive data.
   * \param buffer buffer for receiving
   * \param size_buffer size of buffer
   * \return return number of bytes received if OK, -1 on error
   */
  int64_t Receive(char * buffer, int64_t size_buffer);
  /*!
   * \brief Get socket's file descriptor
   * \return socket's file descriptor
   */
  int Socket() const;

 private:
  /*!
   * \brief socket's file descriptor
   */
  int socket_;
};
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_TCP_SOCKET_H_
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <dgl/sampler.h> #include <dgl/sampler.h>
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dgl/immutable_graph.h> #include <dgl/immutable_graph.h>
#include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <cstdlib> #include <cstdlib>
#include <cmath> #include <cmath>
......
from dgl import backend as F
import numpy as np
import scipy as sp
import dgl
from dgl import utils
import os
import time
def generate_rand_graph(n):
    """Create a read-only DGLGraph from a random sparse n-by-n adjacency matrix."""
    mask = sp.sparse.random(n, n, density=0.1, format='coo') != 0
    adjacency = mask.astype(np.int64)
    return dgl.DGLGraph(adjacency, readonly=True)
def start_trainer():
    """Receiver side: accept one sampled NodeFlow and validate its structure.

    NOTE(review): the trainer rebuilds its own parent graph with
    generate_rand_graph(), which is not seeded — this assumes the
    assertions only depend on properties shared by any such graph
    (node/edge counts around the sampled seed); confirm.
    """
    g = generate_rand_graph(100)
    # Block until the remote sampler sends one subgraph.
    recv = dgl.contrib.sampling.SamplerReceiver(ip="127.0.0.1", port=50051)
    subg = recv.recv(g)
    # The last layer holds the single seed node of the 1-seed sampler.
    seed_ids = subg.layer_parent_nid(-1)
    assert len(seed_ids) == 1
    src, dst, eid = g.in_edges(seed_ids, form='all')
    # Subgraph contains the seed plus all of its in-neighbors.
    assert subg.number_of_nodes() == len(src) + 1
    assert subg.number_of_edges() == len(src)
    assert seed_ids == subg.layer_parent_nid(-1)
    child_src, child_dst, child_eid = subg.in_edges(subg.layer_nid(-1), form='all')
    assert F.array_equal(child_src, subg.layer_nid(0))
    # Child node ids must map back to the parent graph's source nodes.
    src1 = subg.map_to_parent_nid(child_src)
    assert F.array_equal(src1, src)
    time.sleep(3)  # wait all senders to finalize their jobs
def start_sampler():
    """Sender side: draw NodeFlows from the graph and ship the first one."""
    g = generate_rand_graph(100)
    sender = dgl.contrib.sampling.SamplerSender(ip="127.0.0.1", port=50051)
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 1, 100, neighbor_type='in', num_workers=4)
    for subg in sampler:
        # One subgraph is enough: the trainer recv()s exactly once.
        sender.send(subg)
        break
    time.sleep(1)  # give the socket time to flush before exiting
if __name__ == '__main__':
    # Run trainer (receiver) and sampler (sender) as two processes.
    # NOTE(review): os.fork() is POSIX-only, so this test cannot run on
    # Windows as written.
    pid = os.fork()
    if pid == 0:
        # Child: start the receiver first so it is listening when the
        # sender connects.
        start_trainer()
    else:
        # Parent: give the receiver a head start to bind/listen.
        time.sleep(1)
        start_sampler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment