Unverified Commit 8058f1c5 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[Fix] Update inner API of distributed sampler (#478)

* update inner API of distributed sampler

* update
parent da3ab84c
# This file contains DGL distributed samplers APIs.
from ...network import _send_subgraph, _recv_subgraph
from ...network import _create_sampler_sender, _create_sampler_receiver
from ...network import _finalize_sampler_sender, _finalize_sampler_receiver
from ...network import _create_sender, _create_receiver
from ...network import _finalize_sender, _finalize_receiver
from multiprocessing import Pool
from abc import ABCMeta, abstractmethod
......@@ -68,14 +68,14 @@ class SamplerSender(object):
def __init__(self, ip, port):
self._ip = ip
self._port = port
self._sender = _create_sampler_sender(ip, port)
self._sender = _create_sender(ip, port)
def __del__(self):
"""Finalize Sender
"""
# _finalize_sampler_sender will send a special message
# _finalize_sender will send a special message
# to tell the remote trainer machine that it has finished its job.
_finalize_sampler_sender(self._sender)
_finalize_sender(self._sender)
def send(self, nodeflow):
"""Send sampled subgraph (NodeFlow) to remote trainer.
......@@ -109,7 +109,7 @@ class SamplerReceiver(object):
self._ip = ip
self._port = port
self._num_sender = num_sender
self._receiver = _create_sampler_receiver(ip, port, num_sender)
self._receiver = _create_receiver(ip, port, num_sender)
def __del__(self):
"""Finalize Receiver
......@@ -117,7 +117,7 @@ class SamplerReceiver(object):
_finalize_sampler_receiver method will clean up the
back-end threads started by the SamplerReceiver.
"""
_finalize_sampler_receiver(self._receiver)
_finalize_receiver(self._receiver)
def recv(self, graph):
"""Receive a NodeFlow object from remote sampler.
......
......@@ -8,9 +8,7 @@ from . import utils
_init_api("dgl.network")
############################# Distributed Sampler #############################
def _create_sampler_sender(ip_addr, port):
def _create_sender(ip_addr, port):
"""Create a sender communicator via C socket.
Parameters
......@@ -22,7 +20,7 @@ def _create_sampler_sender(ip_addr, port):
"""
return _CAPI_DGLSenderCreate(ip_addr, port)
def _create_sampler_receiver(ip_addr, port, num_sender):
def _create_receiver(ip_addr, port, num_sender):
"""Create a receiver communicator via C socket.
Parameters
......@@ -78,7 +76,7 @@ def _recv_subgraph(receiver, graph):
hdl = unwrap_to_ptr_list(_CAPI_ReceiverRecvSubgraph(receiver))
return NodeFlow(graph, hdl[0])
def _finalize_sampler_sender(sender):
def _finalize_sender(sender):
"""Finalize Sender communicator
Parameters
......@@ -88,7 +86,7 @@ def _finalize_sampler_sender(sender):
"""
_CAPI_DGLFinalizeCommunicator(sender)
def _finalize_sampler_receiver(receiver):
def _finalize_receiver(receiver):
"""Finalize Receiver communicator
Parameters
......
......@@ -20,11 +20,6 @@ using dgl::runtime::NDArray;
namespace dgl {
namespace network {
static char* sender_data_buffer = nullptr;
static char* recv_data_buffer = nullptr;
///////////////////////// Distributed Sampler /////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0];
......@@ -34,7 +29,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
LOG(FATAL) << "Initialize network communicator (sender) error.";
}
try {
sender_data_buffer = new char[kMaxBufferSize];
comm->SetBuffer(new char[kMaxBufferSize]);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
}
......@@ -52,7 +47,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
LOG(FATAL) << "Initialize network communicator (receiver) error.";
}
try {
recv_data_buffer = new char[kMaxBufferSize];
comm->SetBuffer(new char[kMaxBufferSize]);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
}
......@@ -73,7 +68,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
auto csr = ptr->GetInCSR();
// Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph(
sender_data_buffer,
comm->GetBuffer(),
csr,
node_mapping,
edge_mapping,
......@@ -81,7 +76,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
flow_offsets);
CHECK_GT(data_size, 0);
// Send msg via network
int64_t size = comm->Send(sender_data_buffer, data_size);
int64_t size = comm->Send(comm->GetBuffer(), data_size);
if (size <= 0) {
LOG(ERROR) << "Send message error (size: " << size << ")";
}
......@@ -92,15 +87,15 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0];
network::Communicator* comm = static_cast<network::Communicator*>(chandle);
// Recv data from network
int64_t size = comm->Receive(recv_data_buffer, kMaxBufferSize);
int64_t size = comm->Receive(comm->GetBuffer(), kMaxBufferSize);
if (size <= 0) {
LOG(ERROR) << "Receive error: (size: " << size << ")";
}
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(recv_data_buffer,
&csr,
network::DeserializeSampledSubgraph(comm->GetBuffer(),
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
&(nf->layer_offsets),
......@@ -116,8 +111,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeCommunicator")
CommunicatorHandle chandle = args[0];
network::Communicator* comm = static_cast<network::Communicator*>(chandle);
comm->Finalize();
delete [] sender_data_buffer;
delete [] recv_data_buffer;
});
} // namespace network
......
......@@ -67,6 +67,16 @@ class Communicator {
* \brief Finalize the Communicator class
*/
virtual void Finalize() = 0;
/*!
* \brief Set pointer of memory buffer allocated for Communicator
*/
virtual void SetBuffer(char* buffer) = 0;
/*!
* \brief Get pointer of memory buffer allocated for Communicator
*/
virtual char* GetBuffer() = 0;
};
} // namespace network
......
......@@ -162,6 +162,9 @@ void SocketCommunicator::FinalizeSender() {
delete socket_[0];
socket_[0] = nullptr;
}
if (buffer_ != nullptr) {
delete [] buffer_;
}
}
void SocketCommunicator::FinalizeReceiver() {
......@@ -209,5 +212,15 @@ int64_t SocketCommunicator::Receive(char* dest, int64_t max_size) {
return queue_->Remove(dest, max_size);
}
void SocketCommunicator::SetBuffer(char* buffer) {
// Set memory buffer allocated for current Communicator
buffer_ = buffer;
}
char* SocketCommunicator::GetBuffer() {
// Get memory buffer allocated for current Communicator
return buffer_;
}
} // namespace network
} // namespace dgl
......@@ -67,6 +67,16 @@ class SocketCommunicator : public Communicator {
*/
void Finalize();
/*!
* \brief Set pointer of memory buffer allocated for Communicator
*/
void SetBuffer(char* buffer);
/*!
* \brief Get pointer of memory buffer allocated for Communicator
*/
char* GetBuffer();
private:
/*!
* \brief Is a sender or reciever node?
......@@ -98,6 +108,11 @@ class SocketCommunicator : public Communicator {
*/
MessageQueue* queue_;
/*!
* \brief Memory buffer for communicator
*/
char* buffer_ = nullptr;
/*!
* \brief Initalize sender node
* \param ip receiver ip address
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment