Unverified Commit 99bc3ab8 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

Use private buffer instead of global buffer (#511)

parent 8c79885d
...@@ -20,9 +20,6 @@ using dgl::runtime::NDArray; ...@@ -20,9 +20,6 @@ using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace network { namespace network {
static char* SEND_BUFFER = nullptr;
static char* RECV_BUFFER = nullptr;
// Wrapper for Send api // Wrapper for Send api
static void SendData(network::Sender* sender, static void SendData(network::Sender* sender,
const char* data, const char* data,
...@@ -46,12 +43,13 @@ static void RecvData(network::Receiver* receiver, ...@@ -46,12 +43,13 @@ static void RecvData(network::Receiver* receiver,
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Sender* sender = new network::SocketSender();
try { try {
SEND_BUFFER = new char[kMaxBufferSize]; char* buffer = new char[kMaxBufferSize];
sender->SetBuffer(buffer);
} catch (const std::bad_alloc&) { } catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize; LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
} }
network::Sender* sender = new network::SocketSender();
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender); CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender);
*rv = chandle; *rv = chandle;
}); });
...@@ -61,7 +59,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender") ...@@ -61,7 +59,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->Finalize(); sender->Finalize();
delete [] SEND_BUFFER;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
...@@ -96,10 +93,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") ...@@ -96,10 +93,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR(); auto csr = ptr->GetInCSR();
// Write control message // Write control message
*SEND_BUFFER = CONTROL_NODEFLOW; char* buffer = sender->GetBuffer();
*buffer = CONTROL_NODEFLOW;
// Serialize nodeflow to data buffer // Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph( int64_t data_size = network::SerializeSampledSubgraph(
SEND_BUFFER+sizeof(CONTROL_NODEFLOW), buffer+sizeof(CONTROL_NODEFLOW),
csr, csr,
node_mapping, node_mapping,
edge_mapping, edge_mapping,
...@@ -108,7 +106,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") ...@@ -108,7 +106,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
CHECK_GT(data_size, 0); CHECK_GT(data_size, 0);
data_size += sizeof(CONTROL_NODEFLOW); data_size += sizeof(CONTROL_NODEFLOW);
// Send msg via network // Send msg via network
SendData(sender, SEND_BUFFER, data_size, recv_id); SendData(sender, buffer, data_size, recv_id);
}); });
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
...@@ -116,19 +114,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal") ...@@ -116,19 +114,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
int recv_id = args[1]; int recv_id = args[1];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
*SEND_BUFFER = CONTROL_END_SIGNAL; char* buffer = sender->GetBuffer();
*buffer = CONTROL_END_SIGNAL;
// Send msg via network // Send msg via network
SendData(sender, SEND_BUFFER, sizeof(CONTROL_END_SIGNAL), recv_id); SendData(sender, buffer, sizeof(CONTROL_END_SIGNAL), recv_id);
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Receiver* receiver = new network::SocketReceiver();
try { try {
RECV_BUFFER = new char[kMaxBufferSize]; char* buffer = new char[kMaxBufferSize];
receiver->SetBuffer(buffer);
} catch (const std::bad_alloc&) { } catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize; LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
} }
network::Receiver* receiver = new network::SocketReceiver();
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver); CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle; *rv = chandle;
}); });
...@@ -138,7 +138,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver") ...@@ -138,7 +138,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize(); receiver->Finalize();
delete [] RECV_BUFFER;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
...@@ -156,13 +155,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph") ...@@ -156,13 +155,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network // Recv data from network
RecvData(receiver, RECV_BUFFER, kMaxBufferSize); char* buffer = receiver->GetBuffer();
int control = *RECV_BUFFER; RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer;
if (control == CONTROL_NODEFLOW) { if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow(); NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr; ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer // Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER+sizeof(CONTROL_NODEFLOW), network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
&(csr), &(csr),
&(nf->node_mapping), &(nf->node_mapping),
&(nf->edge_mapping), &(nf->edge_mapping),
......
...@@ -52,6 +52,17 @@ class Sender { ...@@ -52,6 +52,17 @@ class Sender {
* \brief Finalize Sender * \brief Finalize Sender
*/ */
virtual void Finalize() = 0; virtual void Finalize() = 0;
/*!
* \brief Get data buffer
* \return buffer pointer
*/
virtual char* GetBuffer() = 0;
/*!
* \brief Set data buffer
*/
virtual void SetBuffer(char* buffer) = 0;
}; };
/*! /*!
...@@ -90,6 +101,17 @@ class Receiver { ...@@ -90,6 +101,17 @@ class Receiver {
* \brief Finalize Receiver * \brief Finalize Receiver
*/ */
virtual void Finalize() = 0; virtual void Finalize() = 0;
/*!
* \brief Get data buffer
* \return buffer pointer
*/
virtual char* GetBuffer() = 0;
/*!
* \brief Set data buffer
*/
virtual void SetBuffer(char* buffer) = 0;
}; };
} // namespace network } // namespace network
......
...@@ -93,6 +93,15 @@ void SocketSender::Finalize() { ...@@ -93,6 +93,15 @@ void SocketSender::Finalize() {
client = nullptr; client = nullptr;
} }
} }
delete buffer_;
}
char* SocketSender::GetBuffer() {
return buffer_;
}
void SocketSender::SetBuffer(char* buffer) {
buffer_ = buffer;
} }
bool SocketReceiver::Wait(const char* ip, bool SocketReceiver::Wait(const char* ip,
...@@ -190,6 +199,15 @@ void SocketReceiver::Finalize() { ...@@ -190,6 +199,15 @@ void SocketReceiver::Finalize() {
socket_[i] = nullptr; socket_[i] = nullptr;
} }
} }
delete buffer_;
}
char* SocketReceiver::GetBuffer() {
return buffer_;
}
void SocketReceiver::SetBuffer(char* buffer) {
buffer_ = buffer;
} }
} // namespace network } // namespace network
......
...@@ -71,6 +71,17 @@ class SocketSender : public Sender { ...@@ -71,6 +71,17 @@ class SocketSender : public Sender {
*/ */
void Finalize(); void Finalize();
/*!
* \brief Get data buffer
* \return buffer pointer
*/
char* GetBuffer();
/*!
* \brief Set data buffer
*/
void SetBuffer(char* buffer);
private: private:
/*! /*!
* \brief socket map * \brief socket map
...@@ -81,6 +92,11 @@ class SocketSender : public Sender { ...@@ -81,6 +92,11 @@ class SocketSender : public Sender {
* \brief receiver address map * \brief receiver address map
*/ */
std::unordered_map<int, Addr> receiver_addr_map_; std::unordered_map<int, Addr> receiver_addr_map_;
/*!
* \brief data buffer
*/
char* buffer_;
}; };
/*! /*!
...@@ -118,6 +134,17 @@ class SocketReceiver : public Receiver { ...@@ -118,6 +134,17 @@ class SocketReceiver : public Receiver {
*/ */
void Finalize(); void Finalize();
/*!
* \brief Get data buffer
* \return buffer pointer
*/
char* GetBuffer();
/*!
* \brief Set data buffer
*/
void SetBuffer(char* buffer);
private: private:
/*! /*!
* \brief number of sender * \brief number of sender
...@@ -144,6 +171,11 @@ class SocketReceiver : public Receiver { ...@@ -144,6 +171,11 @@ class SocketReceiver : public Receiver {
*/ */
MessageQueue* queue_; MessageQueue* queue_;
/*!
* \brief data buffer
*/
char* buffer_;
/*! /*!
* \brief Process received message in independent threads * \brief Process received message in independent threads
* \param socket new accpeted socket * \param socket new accpeted socket
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment