Unverified commit 747a8bee, authored by Chao Ma, committed by GitHub
Browse files

[Fix] Hold NDArray reference during send() (#740)

* fix NDArray reference

* fix lint

* capture NDArray in the closure
parent 16061925
...@@ -215,50 +215,50 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow") ...@@ -215,50 +215,50 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
Message node_mapping_msg; Message node_mapping_msg;
node_mapping_msg.data = static_cast<char*>(node_mapping->data); node_mapping_msg.data = static_cast<char*>(node_mapping->data);
node_mapping_msg.size = node_mapping.GetSize(); node_mapping_msg.size = node_mapping.GetSize();
node_mapping_msg.aux_handler = &node_mapping; // capture the array in the closure
node_mapping_msg.deallocator = NDArrayDeleter; node_mapping_msg.deallocator = [node_mapping](Message*) {};
CHECK_NE(sender->Send(node_mapping_msg, recv_id), -1); CHECK_NE(sender->Send(node_mapping_msg, recv_id), -1);
// send edge_mapping // send edge_mapping
Message edge_mapping_msg; Message edge_mapping_msg;
edge_mapping_msg.data = static_cast<char*>(edge_mapping->data); edge_mapping_msg.data = static_cast<char*>(edge_mapping->data);
edge_mapping_msg.size = edge_mapping.GetSize(); edge_mapping_msg.size = edge_mapping.GetSize();
edge_mapping_msg.aux_handler = &edge_mapping; // capture the array in the closure
edge_mapping_msg.deallocator = NDArrayDeleter; edge_mapping_msg.deallocator = [edge_mapping](Message*) {};
CHECK_NE(sender->Send(edge_mapping_msg, recv_id), -1); CHECK_NE(sender->Send(edge_mapping_msg, recv_id), -1);
// send layer_offsets // send layer_offsets
Message layer_offsets_msg; Message layer_offsets_msg;
layer_offsets_msg.data = static_cast<char*>(layer_offsets->data); layer_offsets_msg.data = static_cast<char*>(layer_offsets->data);
layer_offsets_msg.size = layer_offsets.GetSize(); layer_offsets_msg.size = layer_offsets.GetSize();
layer_offsets_msg.aux_handler = &layer_offsets; // capture the array in the closure
layer_offsets_msg.deallocator = NDArrayDeleter; layer_offsets_msg.deallocator = [layer_offsets](Message*) {};
CHECK_NE(sender->Send(layer_offsets_msg, recv_id), -1); CHECK_NE(sender->Send(layer_offsets_msg, recv_id), -1);
// send flow_offsets // send flow_offsets
Message flow_offsets_msg; Message flow_offsets_msg;
flow_offsets_msg.data = static_cast<char*>(flow_offsets->data); flow_offsets_msg.data = static_cast<char*>(flow_offsets->data);
flow_offsets_msg.size = flow_offsets.GetSize(); flow_offsets_msg.size = flow_offsets.GetSize();
flow_offsets_msg.aux_handler = &flow_offsets; // capture the array in the closure
flow_offsets_msg.deallocator = NDArrayDeleter; flow_offsets_msg.deallocator = [flow_offsets](Message*) {};
CHECK_NE(sender->Send(flow_offsets_msg, recv_id), -1); CHECK_NE(sender->Send(flow_offsets_msg, recv_id), -1);
// send csr->indptr // send csr->indptr
Message indptr_msg; Message indptr_msg;
indptr_msg.data = static_cast<char*>(indptr->data); indptr_msg.data = static_cast<char*>(indptr->data);
indptr_msg.size = indptr.GetSize(); indptr_msg.size = indptr.GetSize();
indptr_msg.aux_handler = &indptr; // capture the array in the closure
indptr_msg.deallocator = NDArrayDeleter; indptr_msg.deallocator = [indptr](Message*) {};
CHECK_NE(sender->Send(indptr_msg, recv_id), -1); CHECK_NE(sender->Send(indptr_msg, recv_id), -1);
// send csr->indices // send csr->indices
Message indices_msg; Message indices_msg;
indices_msg.data = static_cast<char*>(indice->data); indices_msg.data = static_cast<char*>(indice->data);
indices_msg.size = indice.GetSize(); indices_msg.size = indice.GetSize();
indices_msg.aux_handler = &indice; // capture the array in the closure
indices_msg.deallocator = NDArrayDeleter; indices_msg.deallocator = [indice](Message*) {};
CHECK_NE(sender->Send(indices_msg, recv_id), -1); CHECK_NE(sender->Send(indices_msg, recv_id), -1);
// send csr->edge_ids // send csr->edge_ids
Message edge_ids_msg; Message edge_ids_msg;
edge_ids_msg.data = static_cast<char*>(csr->edge_ids()->data); edge_ids_msg.data = static_cast<char*>(edge_ids->data);
edge_ids_msg.size = csr->edge_ids().GetSize(); edge_ids_msg.size = edge_ids.GetSize();
edge_ids_msg.aux_handler = &edge_ids; // capture the array in the closure
edge_ids_msg.deallocator = NDArrayDeleter; edge_ids_msg.deallocator = [edge_ids](Message*) {};
CHECK_NE(sender->Send(edge_ids_msg, recv_id), -1); CHECK_NE(sender->Send(edge_ids_msg, recv_id), -1);
}); });
......
...@@ -24,13 +24,6 @@ namespace network { ...@@ -24,13 +24,6 @@ namespace network {
// TODO(chao): Make this number configurable // TODO(chao): Make this number configurable
const int64_t kQueueSize = 200 * 1024 * 1024; const int64_t kQueueSize = 200 * 1024 * 1024;
/*!
* \brief Free memory buffer of NodeFlow
*
* Deallocator callback for a Message: reinterprets msg->aux_handler as an
* NDArray* and deletes that object.
*
* NOTE(review): the senders visible in this diff store the address of a named
* NDArray variable (e.g. &node_mapping) in aux_handler; presumably that object
* is not heap-allocated, in which case deleting it here is invalid -- confirm
* ownership against the callers.
*
* \param msg message whose aux_handler is deleted
*/
inline void NDArrayDeleter(Message* msg) {
delete reinterpret_cast<NDArray*>(msg->aux_handler);
}
/*! /*!
* \brief Message type for DGL distributed training * \brief Message type for DGL distributed training
*/ */
......
...@@ -71,7 +71,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) { ...@@ -71,7 +71,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
return QUEUE_CLOSE; return QUEUE_CLOSE;
} }
Message & old_msg = queue_.front(); Message old_msg = queue_.front();
queue_.pop(); queue_.pop();
msg->data = old_msg.data; msg->data = old_msg.data;
msg->size = old_msg.size; msg->size = old_msg.size;
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#ifndef DGL_GRAPH_NETWORK_MSG_QUEUE_H_ #ifndef DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#define DGL_GRAPH_NETWORK_MSG_QUEUE_H_ #define DGL_GRAPH_NETWORK_MSG_QUEUE_H_
#include <dgl/runtime/ndarray.h>
#include <queue> #include <queue>
#include <set> #include <set>
#include <string> #include <string>
...@@ -54,10 +56,6 @@ struct Message { ...@@ -54,10 +56,6 @@ struct Message {
* \brief message size in bytes * \brief message size in bytes
*/ */
int64_t size; int64_t size;
/*!
* \brief aux_data pointer handler
*/
void* aux_handler;
/*! /*!
* \brief user-defined deallocator, which can be nullptr * \brief user-defined deallocator, which can be nullptr
*/ */
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment