Unverified Commit 99bc3ab8 authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

Use private buffer instead of global buffer (#511)

parent 8c79885d
......@@ -20,9 +20,6 @@ using dgl::runtime::NDArray;
namespace dgl {
namespace network {
static char* SEND_BUFFER = nullptr;
static char* RECV_BUFFER = nullptr;
// Wrapper for Send api
static void SendData(network::Sender* sender,
const char* data,
......@@ -46,12 +43,13 @@ static void RecvData(network::Receiver* receiver,
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Sender* sender = new network::SocketSender();
try {
SEND_BUFFER = new char[kMaxBufferSize];
char* buffer = new char[kMaxBufferSize];
sender->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
}
network::Sender* sender = new network::SocketSender();
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender);
*rv = chandle;
});
......@@ -61,7 +59,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
CommunicatorHandle chandle = args[0];
network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->Finalize();
delete [] SEND_BUFFER;
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
......@@ -96,10 +93,11 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR();
// Write control message
*SEND_BUFFER = CONTROL_NODEFLOW;
char* buffer = sender->GetBuffer();
*buffer = CONTROL_NODEFLOW;
// Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph(
SEND_BUFFER+sizeof(CONTROL_NODEFLOW),
buffer+sizeof(CONTROL_NODEFLOW),
csr,
node_mapping,
edge_mapping,
......@@ -108,7 +106,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
CHECK_GT(data_size, 0);
data_size += sizeof(CONTROL_NODEFLOW);
// Send msg via network
SendData(sender, SEND_BUFFER, data_size, recv_id);
SendData(sender, buffer, data_size, recv_id);
});
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
......@@ -116,19 +114,21 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
network::Sender* sender = static_cast<network::Sender*>(chandle);
*SEND_BUFFER = CONTROL_END_SIGNAL;
char* buffer = sender->GetBuffer();
*buffer = CONTROL_END_SIGNAL;
// Send msg via network
SendData(sender, SEND_BUFFER, sizeof(CONTROL_END_SIGNAL), recv_id);
SendData(sender, buffer, sizeof(CONTROL_END_SIGNAL), recv_id);
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Receiver* receiver = new network::SocketReceiver();
try {
RECV_BUFFER = new char[kMaxBufferSize];
char* buffer = new char[kMaxBufferSize];
receiver->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
}
network::Receiver* receiver = new network::SocketReceiver();
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle;
});
......@@ -138,7 +138,6 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize();
delete [] RECV_BUFFER;
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
......@@ -156,13 +155,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network
RecvData(receiver, RECV_BUFFER, kMaxBufferSize);
int control = *RECV_BUFFER;
char* buffer = receiver->GetBuffer();
RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer;
if (control == CONTROL_NODEFLOW) {
NodeFlow* nf = new NodeFlow();
ImmutableGraph::CSR::Ptr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(RECV_BUFFER+sizeof(CONTROL_NODEFLOW),
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
......
......@@ -52,6 +52,17 @@ class Sender {
* \brief Finalize Sender
*/
virtual void Finalize() = 0;
/*!
* \brief Get data buffer
* \return buffer pointer
*/
virtual char* GetBuffer() = 0;
/*!
* \brief Set data buffer
*/
virtual void SetBuffer(char* buffer) = 0;
};
/*!
......@@ -90,6 +101,17 @@ class Receiver {
* \brief Finalize Receiver
*/
virtual void Finalize() = 0;
/*!
* \brief Get data buffer
* \return buffer pointer
*/
virtual char* GetBuffer() = 0;
/*!
* \brief Set data buffer
*/
virtual void SetBuffer(char* buffer) = 0;
};
} // namespace network
......
......@@ -93,6 +93,15 @@ void SocketSender::Finalize() {
client = nullptr;
}
}
delete buffer_;
}
char* SocketSender::GetBuffer() {
return buffer_;
}
void SocketSender::SetBuffer(char* buffer) {
buffer_ = buffer;
}
bool SocketReceiver::Wait(const char* ip,
......@@ -190,6 +199,15 @@ void SocketReceiver::Finalize() {
socket_[i] = nullptr;
}
}
delete buffer_;
}
char* SocketReceiver::GetBuffer() {
return buffer_;
}
void SocketReceiver::SetBuffer(char* buffer) {
buffer_ = buffer;
}
} // namespace network
......
......@@ -71,6 +71,17 @@ class SocketSender : public Sender {
*/
void Finalize();
/*!
* \brief Get data buffer
* \return buffer pointer
*/
char* GetBuffer();
/*!
* \brief Set data buffer
*/
void SetBuffer(char* buffer);
private:
/*!
* \brief socket map
......@@ -81,6 +92,11 @@ class SocketSender : public Sender {
* \brief receiver address map
*/
std::unordered_map<int, Addr> receiver_addr_map_;
/*!
* \brief data buffer
*/
char* buffer_;
};
/*!
......@@ -118,6 +134,17 @@ class SocketReceiver : public Receiver {
*/
void Finalize();
/*!
* \brief Get data buffer
* \return buffer pointer
*/
char* GetBuffer();
/*!
* \brief Set data buffer
*/
void SetBuffer(char* buffer);
private:
/*!
* \brief number of sender
......@@ -144,6 +171,11 @@ class SocketReceiver : public Receiver {
*/
MessageQueue* queue_;
/*!
* \brief data buffer
*/
char* buffer_;
/*!
* \brief Process received message in independent threads
* \param socket new accpeted socket
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment