Unverified Commit 8058f1c5 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[Fix] Update inner API of distributed sampler (#478)

* update inner API of distributed sampler

* update
parent da3ab84c
# This file contains DGL distributed samplers APIs. # This file contains DGL distributed samplers APIs.
from ...network import _send_subgraph, _recv_subgraph from ...network import _send_subgraph, _recv_subgraph
from ...network import _create_sampler_sender, _create_sampler_receiver from ...network import _create_sender, _create_receiver
from ...network import _finalize_sampler_sender, _finalize_sampler_receiver from ...network import _finalize_sender, _finalize_receiver
from multiprocessing import Pool from multiprocessing import Pool
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
...@@ -68,14 +68,14 @@ class SamplerSender(object): ...@@ -68,14 +68,14 @@ class SamplerSender(object):
def __init__(self, ip, port): def __init__(self, ip, port):
self._ip = ip self._ip = ip
self._port = port self._port = port
self._sender = _create_sampler_sender(ip, port) self._sender = _create_sender(ip, port)
def __del__(self): def __del__(self):
"""Finalize Sender """Finalize Sender
""" """
# _finalize_sampler_sender will send a special message # _finalize_sender will send a special message
# to tell the remote trainer machine that it has finished its job. # to tell the remote trainer machine that it has finished its job.
_finalize_sampler_sender(self._sender) _finalize_sender(self._sender)
def send(self, nodeflow): def send(self, nodeflow):
"""Send sampled subgraph (NodeFlow) to remote trainer. """Send sampled subgraph (NodeFlow) to remote trainer.
...@@ -109,7 +109,7 @@ class SamplerReceiver(object): ...@@ -109,7 +109,7 @@ class SamplerReceiver(object):
self._ip = ip self._ip = ip
self._port = port self._port = port
self._num_sender = num_sender self._num_sender = num_sender
self._receiver = _create_sampler_receiver(ip, port, num_sender) self._receiver = _create_receiver(ip, port, num_sender)
def __del__(self): def __del__(self):
"""Finalize Receiver """Finalize Receiver
...@@ -117,7 +117,7 @@ class SamplerReceiver(object): ...@@ -117,7 +117,7 @@ class SamplerReceiver(object):
_finalize_sampler_receiver method will clean up the _finalize_sampler_receiver method will clean up the
back-end threads started by the SamplerReceiver. back-end threads started by the SamplerReceiver.
""" """
_finalize_sampler_receiver(self._receiver) _finalize_receiver(self._receiver)
def recv(self, graph): def recv(self, graph):
"""Receive a NodeFlow object from remote sampler. """Receive a NodeFlow object from remote sampler.
......
...@@ -8,9 +8,7 @@ from . import utils ...@@ -8,9 +8,7 @@ from . import utils
_init_api("dgl.network") _init_api("dgl.network")
############################# Distributed Sampler ############################# def _create_sender(ip_addr, port):
def _create_sampler_sender(ip_addr, port):
"""Create a sender communicator via C socket. """Create a sender communicator via C socket.
Parameters Parameters
...@@ -22,7 +20,7 @@ def _create_sampler_sender(ip_addr, port): ...@@ -22,7 +20,7 @@ def _create_sampler_sender(ip_addr, port):
""" """
return _CAPI_DGLSenderCreate(ip_addr, port) return _CAPI_DGLSenderCreate(ip_addr, port)
def _create_sampler_receiver(ip_addr, port, num_sender): def _create_receiver(ip_addr, port, num_sender):
"""Create a receiver communicator via C socket. """Create a receiver communicator via C socket.
Parameters Parameters
...@@ -78,7 +76,7 @@ def _recv_subgraph(receiver, graph): ...@@ -78,7 +76,7 @@ def _recv_subgraph(receiver, graph):
hdl = unwrap_to_ptr_list(_CAPI_ReceiverRecvSubgraph(receiver)) hdl = unwrap_to_ptr_list(_CAPI_ReceiverRecvSubgraph(receiver))
return NodeFlow(graph, hdl[0]) return NodeFlow(graph, hdl[0])
def _finalize_sampler_sender(sender): def _finalize_sender(sender):
"""Finalize Sender communicator """Finalize Sender communicator
Parameters Parameters
...@@ -88,7 +86,7 @@ def _finalize_sampler_sender(sender): ...@@ -88,7 +86,7 @@ def _finalize_sampler_sender(sender):
""" """
_CAPI_DGLFinalizeCommunicator(sender) _CAPI_DGLFinalizeCommunicator(sender)
def _finalize_sampler_receiver(receiver): def _finalize_receiver(receiver):
"""Finalize Receiver communicator """Finalize Receiver communicator
Parameters Parameters
......
...@@ -20,11 +20,6 @@ using dgl::runtime::NDArray; ...@@ -20,11 +20,6 @@ using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace network { namespace network {
static char* sender_data_buffer = nullptr;
static char* recv_data_buffer = nullptr;
///////////////////////// Distributed Sampler /////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string ip = args[0]; std::string ip = args[0];
...@@ -34,7 +29,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") ...@@ -34,7 +29,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
LOG(FATAL) << "Initialize network communicator (sender) error."; LOG(FATAL) << "Initialize network communicator (sender) error.";
} }
try { try {
sender_data_buffer = new char[kMaxBufferSize]; comm->SetBuffer(new char[kMaxBufferSize]);
} catch (const std::bad_alloc&) { } catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize; LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
} }
...@@ -52,7 +47,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") ...@@ -52,7 +47,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
LOG(FATAL) << "Initialize network communicator (receiver) error."; LOG(FATAL) << "Initialize network communicator (receiver) error.";
} }
try { try {
recv_data_buffer = new char[kMaxBufferSize]; comm->SetBuffer(new char[kMaxBufferSize]);
} catch (const std::bad_alloc&) { } catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize; LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
} }
...@@ -73,7 +68,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") ...@@ -73,7 +68,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
auto csr = ptr->GetInCSR(); auto csr = ptr->GetInCSR();
// Serialize nodeflow to data buffer // Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph( int64_t data_size = network::SerializeSampledSubgraph(
sender_data_buffer, comm->GetBuffer(),
csr, csr,
node_mapping, node_mapping,
edge_mapping, edge_mapping,
...@@ -81,7 +76,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") ...@@ -81,7 +76,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
flow_offsets); flow_offsets);
CHECK_GT(data_size, 0); CHECK_GT(data_size, 0);
// Send msg via network // Send msg via network
int64_t size = comm->Send(sender_data_buffer, data_size); int64_t size = comm->Send(comm->GetBuffer(), data_size);
if (size <= 0) { if (size <= 0) {
LOG(ERROR) << "Send message error (size: " << size << ")"; LOG(ERROR) << "Send message error (size: " << size << ")";
} }
...@@ -92,15 +87,15 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph") ...@@ -92,15 +87,15 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Communicator* comm = static_cast<network::Communicator*>(chandle); network::Communicator* comm = static_cast<network::Communicator*>(chandle);
// Recv data from network // Recv data from network
int64_t size = comm->Receive(recv_data_buffer, kMaxBufferSize); int64_t size = comm->Receive(comm->GetBuffer(), kMaxBufferSize);
if (size <= 0) { if (size <= 0) {
LOG(ERROR) << "Receive error: (size: " << size << ")"; LOG(ERROR) << "Receive error: (size: " << size << ")";
} }
NodeFlow* nf = new NodeFlow(); NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr; ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer // Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(recv_data_buffer, network::DeserializeSampledSubgraph(comm->GetBuffer(),
&csr, &(csr),
&(nf->node_mapping), &(nf->node_mapping),
&(nf->edge_mapping), &(nf->edge_mapping),
&(nf->layer_offsets), &(nf->layer_offsets),
...@@ -116,8 +111,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeCommunicator") ...@@ -116,8 +111,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeCommunicator")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Communicator* comm = static_cast<network::Communicator*>(chandle); network::Communicator* comm = static_cast<network::Communicator*>(chandle);
comm->Finalize(); comm->Finalize();
delete [] sender_data_buffer;
delete [] recv_data_buffer;
}); });
} // namespace network } // namespace network
......
...@@ -67,6 +67,16 @@ class Communicator { ...@@ -67,6 +67,16 @@ class Communicator {
* \brief Finalize the Communicator class * \brief Finalize the Communicator class
*/ */
virtual void Finalize() = 0; virtual void Finalize() = 0;
/*!
* \brief Set pointer of memory buffer allocated for Communicator
* \param buffer pointer to the memory buffer this communicator should use
* \note NOTE(review): the visible callers pass a `new char[kMaxBufferSize]`
*       array, and SocketCommunicator::FinalizeSender releases the stored
*       pointer with `delete []`; ownership presumably transfers to the
*       communicator -- confirm for every implementation.
*/
virtual void SetBuffer(char* buffer) = 0;
/*!
* \brief Get pointer of memory buffer allocated for Communicator
* \return pointer to the communicator's memory buffer
* \note NOTE(review): presumably this is the pointer most recently passed to
*       SetBuffer (that is what SocketCommunicator does) -- confirm for other
*       implementations.
*/
virtual char* GetBuffer() = 0;
}; };
} // namespace network } // namespace network
......
...@@ -162,6 +162,9 @@ void SocketCommunicator::FinalizeSender() { ...@@ -162,6 +162,9 @@ void SocketCommunicator::FinalizeSender() {
delete socket_[0]; delete socket_[0];
socket_[0] = nullptr; socket_[0] = nullptr;
} }
if (buffer_ != nullptr) {
delete [] buffer_;
}
} }
void SocketCommunicator::FinalizeReceiver() { void SocketCommunicator::FinalizeReceiver() {
...@@ -209,5 +212,15 @@ int64_t SocketCommunicator::Receive(char* dest, int64_t max_size) { ...@@ -209,5 +212,15 @@ int64_t SocketCommunicator::Receive(char* dest, int64_t max_size) {
return queue_->Remove(dest, max_size); return queue_->Remove(dest, max_size);
} }
void SocketCommunicator::SetBuffer(char* buffer) {
// Set memory buffer allocated for current Communicator
buffer_ = buffer;
}
char* SocketCommunicator::GetBuffer() {
// Get memory buffer allocated for current Communicator
return buffer_;
}
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
...@@ -67,6 +67,16 @@ class SocketCommunicator : public Communicator { ...@@ -67,6 +67,16 @@ class SocketCommunicator : public Communicator {
*/ */
void Finalize(); void Finalize();
/*!
* \brief Set pointer of memory buffer allocated for Communicator
* \param buffer pointer stored in buffer_; a previously stored pointer is
*        overwritten without being freed
* \note NOTE(review): FinalizeSender releases the stored pointer with
*       `delete []`, so the argument presumably must come from `new char[]`
*       and must not be freed by the caller -- confirm.
*/
void SetBuffer(char* buffer);
/*!
* \brief Get pointer of memory buffer allocated for Communicator
* \return the pointer held in buffer_ (nullptr until SetBuffer is called)
* \note NOTE(review): FinalizeSender deletes the buffer without resetting
*       buffer_, so the value returned after finalization looks dangling --
*       confirm and avoid using it after Finalize().
*/
char* GetBuffer();
private: private:
/*! /*!
* \brief Is a sender or receiver node? * \brief Is a sender or receiver node?
...@@ -98,6 +108,11 @@ class SocketCommunicator : public Communicator { ...@@ -98,6 +108,11 @@ class SocketCommunicator : public Communicator {
*/ */
MessageQueue* queue_; MessageQueue* queue_;
/*!
* \brief Memory buffer for communicator
*
* Assigned by SetBuffer and returned by GetBuffer; defaults to nullptr.
* NOTE(review): FinalizeSender frees it with `delete []`; whether the
* receiver path frees it as well is not visible here -- confirm.
*/
char* buffer_ = nullptr;
/*! /*!
* \brief Initialize sender node * \brief Initialize sender node
* \param ip receiver ip address * \param ip receiver ip address
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment