Unverified Commit 94ecb8eb authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] copy graph index to shared memory. (#634)

* copy graph index to shared memory.

* fix.

* fix.

* fix.

* use a diff name for in-csr and out-csr.

* fix lint.

* remove print.

* add test.

* add comments.
parent 16ec2a8b
...@@ -11,10 +11,8 @@ class GraphData: ...@@ -11,10 +11,8 @@ class GraphData:
def __init__(self, csr, num_feats, graph_name): def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0] num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0] num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64) self.graph = dgl.graph_index.from_csr(csr.indptr, csr.indices, False,
self.graph = dgl.graph_index.from_csr_matrix( 'in', dgl.contrib.graph_store._get_graph_path(graph_name))
dgl.utils.toindex(csr.indptr), dgl.utils.toindex(csr.indices), False,
"in", dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats)) self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10 self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels, self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
...@@ -69,7 +67,7 @@ def main(args): ...@@ -69,7 +67,7 @@ def main(args):
# create GCN model # create GCN model
print('graph name: ' + graph_name) print('graph name: ' + graph_name)
g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem", g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem",
args.num_workers, False) args.num_workers, False, edge_dir='in')
g.ndata['features'] = features g.ndata['features'] = features
g.ndata['labels'] = labels g.ndata['labels'] = labels
g.ndata['train_mask'] = train_mask g.ndata['train_mask'] = train_mask
......
...@@ -208,6 +208,11 @@ class CSR : public GraphInterface { ...@@ -208,6 +208,11 @@ class CSR : public GraphInterface {
return {indptr_, indices_, edge_ids_}; return {indptr_, indices_, edge_ids_};
} }
/*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const {
return !shared_mem_name_.empty();
}
/*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */ /*! \brief Return the reverse of this CSR graph (i.e, a CSC graph) */
CSRPtr Transpose() const; CSRPtr Transpose() const;
...@@ -230,6 +235,13 @@ class CSR : public GraphInterface { ...@@ -230,6 +235,13 @@ class CSR : public GraphInterface {
*/ */
CSR CopyTo(const DLContext& ctx) const; CSR CopyTo(const DLContext& ctx) const;
/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
CSR CopyToSharedMem(const std::string &name) const;
/*! /*!
* \brief Convert the graph to use the given number of bits for storage. * \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64). * \param bits The new number of integer bits (32 or 64).
...@@ -262,6 +274,10 @@ class CSR : public GraphInterface { ...@@ -262,6 +274,10 @@ class CSR : public GraphInterface {
// whether the graph is a multi-graph // whether the graph is a multi-graph
LazyObject<bool> is_multigraph_; LazyObject<bool> is_multigraph_;
// The name of the shared memory to store data.
// If it's empty, data isn't stored in shared memory.
std::string shared_mem_name_;
}; };
class COO : public GraphInterface { class COO : public GraphInterface {
...@@ -478,6 +494,13 @@ class COO : public GraphInterface { ...@@ -478,6 +494,13 @@ class COO : public GraphInterface {
*/ */
COO CopyTo(const DLContext& ctx) const; COO CopyTo(const DLContext& ctx) const;
/*!
* \brief Copy data to shared memory.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
COO CopyToSharedMem(const std::string &name) const;
/*! /*!
* \brief Convert the graph to use the given number of bits for storage. * \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64). * \param bits The new number of integer bits (32 or 64).
...@@ -485,6 +508,11 @@ class COO : public GraphInterface { ...@@ -485,6 +508,11 @@ class COO : public GraphInterface {
*/ */
COO AsNumBits(uint8_t bits) const; COO AsNumBits(uint8_t bits) const;
/*! \brief Indicate whether this uses shared memory. */
bool IsSharedMem() const {
return false;
}
// member getters // member getters
IdArray src() const { return src_; } IdArray src() const { return src_; }
...@@ -512,6 +540,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -512,6 +540,7 @@ class ImmutableGraph: public GraphInterface {
public: public:
/*! \brief Construct an immutable graph from the COO format. */ /*! \brief Construct an immutable graph from the COO format. */
explicit ImmutableGraph(COOPtr coo): coo_(coo) { } explicit ImmutableGraph(COOPtr coo): coo_(coo) { }
/*! /*!
* \brief Construct an immutable graph from the CSR format. * \brief Construct an immutable graph from the CSR format.
* *
...@@ -889,6 +918,9 @@ class ImmutableGraph: public GraphInterface { ...@@ -889,6 +918,9 @@ class ImmutableGraph: public GraphInterface {
if (!in_csr_) { if (!in_csr_) {
if (out_csr_) { if (out_csr_) {
const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose(); const_cast<ImmutableGraph*>(this)->in_csr_ = out_csr_->Transpose();
if (out_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an in-CSR from a shared-memory out CSR. "
<< "It may dramatically increase memory consumption.";
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR(); const_cast<ImmutableGraph*>(this)->in_csr_ = coo_->Transpose()->ToCSR();
...@@ -902,6 +934,9 @@ class ImmutableGraph: public GraphInterface { ...@@ -902,6 +934,9 @@ class ImmutableGraph: public GraphInterface {
if (!out_csr_) { if (!out_csr_) {
if (in_csr_) { if (in_csr_) {
const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose(); const_cast<ImmutableGraph*>(this)->out_csr_ = in_csr_->Transpose();
if (in_csr_->IsSharedMem())
LOG(WARNING) << "We just construct an out-CSR from a shared-memory in CSR. "
<< "It may dramatically increase memory consumption.";
} else { } else {
CHECK(coo_) << "None of CSR, COO exist"; CHECK(coo_) << "None of CSR, COO exist";
const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR(); const_cast<ImmutableGraph*>(this)->out_csr_ = coo_->ToCSR();
...@@ -941,6 +976,14 @@ class ImmutableGraph: public GraphInterface { ...@@ -941,6 +976,14 @@ class ImmutableGraph: public GraphInterface {
*/ */
ImmutableGraph CopyTo(const DLContext& ctx) const; ImmutableGraph CopyTo(const DLContext& ctx) const;
/*!
* \brief Copy data to shared memory.
* \param edge_dir the graph of the specific edge direction to be copied.
* \param name The name of the shared memory.
* \return The graph in the shared memory
*/
ImmutableGraph CopyToSharedMem(const std::string &edge_dir, const std::string &name) const;
/*! /*!
* \brief Convert the graph to use the given number of bits for storage. * \brief Convert the graph to use the given number of bits for storage.
* \param bits The new number of integer bits (32 or 64). * \param bits The new number of integer bits (32 or 64).
...@@ -948,6 +991,77 @@ class ImmutableGraph: public GraphInterface { ...@@ -948,6 +991,77 @@ class ImmutableGraph: public GraphInterface {
*/ */
ImmutableGraph AsNumBits(uint8_t bits) const; ImmutableGraph AsNumBits(uint8_t bits) const;
/*! \brief Create an immutable graph from CSR. */
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr, shared_mem_name);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr, shared_mem_name);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(IdArray indptr, IdArray indices, IdArray edge_ids,
bool multigraph, const std::string &edge_dir,
const std::string &shared_mem_name) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, multigraph,
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr, shared_mem_name);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr, shared_mem_name);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
static ImmutableGraph CreateFromCSR(const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, bool multigraph,
const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges,
multigraph));
if (edge_dir == "in") {
return ImmutableGraph(csr, nullptr, shared_mem_name);
} else if (edge_dir == "out") {
return ImmutableGraph(nullptr, csr, shared_mem_name);
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraph();
}
}
protected: protected:
/* !\brief internal default constructor */ /* !\brief internal default constructor */
ImmutableGraph() {} ImmutableGraph() {}
...@@ -958,6 +1072,16 @@ class ImmutableGraph: public GraphInterface { ...@@ -958,6 +1072,16 @@ class ImmutableGraph: public GraphInterface {
CHECK(AnyGraph()) << "At least one graph structure should exist."; CHECK(AnyGraph()) << "At least one graph structure should exist.";
} }
ImmutableGraph(CSRPtr in_csr, CSRPtr out_csr, const std::string shared_mem_name)
: in_csr_(in_csr), out_csr_(out_csr) {
CHECK(in_csr_ || out_csr_) << "Both CSR are missing.";
this->shared_mem_name_ = shared_mem_name;
}
static std::string GetSharedMemName(const std::string &name, const std::string &edge_dir) {
return name + "_" + edge_dir;
}
/* !\brief return pointer to any available graph structure */ /* !\brief return pointer to any available graph structure */
GraphPtr AnyGraph() const { GraphPtr AnyGraph() const {
if (in_csr_) { if (in_csr_) {
...@@ -975,6 +1099,10 @@ class ImmutableGraph: public GraphInterface { ...@@ -975,6 +1099,10 @@ class ImmutableGraph: public GraphInterface {
CSRPtr out_csr_; CSRPtr out_csr_;
// Store the edge list indexed by edge id (COO) // Store the edge list indexed by edge id (COO)
COOPtr coo_; COOPtr coo_;
// The name of shared memory for this graph.
// If it's empty, the graph isn't stored in shared memory.
std::string shared_mem_name_;
}; };
// inline implementations // inline implementations
......
...@@ -318,7 +318,11 @@ class SharedMemoryStoreServer(object): ...@@ -318,7 +318,11 @@ class SharedMemoryStoreServer(object):
""" """
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port): def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
self.server = None self.server = None
if isinstance(graph_data, (GraphIndex, DGLGraph)): if isinstance(graph_data, GraphIndex):
graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True)
elif isinstance(graph_data, DGLGraph):
graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True) self._graph = DGLGraph(graph_data, multigraph=multigraph, readonly=True)
else: else:
indptr, indices = _to_csr(graph_data, edge_dir, multigraph) indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
......
...@@ -835,6 +835,26 @@ class GraphIndex(object): ...@@ -835,6 +835,26 @@ class GraphIndex(object):
handle = _CAPI_DGLImmutableGraphCopyTo(self._handle, ctx.device_type, ctx.device_id) handle = _CAPI_DGLImmutableGraphCopyTo(self._handle, ctx.device_type, ctx.device_id)
return GraphIndex(handle) return GraphIndex(handle)
def copyto_shared_mem(self, edge_dir, shared_mem_name):
"""Copy this immutable graph index to shared memory.
NOTE: this method only works for immutable graph index
Parameters
----------
edge_dir : string
Indicate which CSR should copy ("in", "out", "both").
shared_mem_name : string
The name of the shared memory.
Returns
-------
GraphIndex
The graph index on the given device context.
"""
handle = _CAPI_DGLImmutableGraphCopyToSharedMem(self._handle, edge_dir, shared_mem_name)
return GraphIndex(handle)
def nbits(self): def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64). """Return the number of integer bits used in the storage (32 or 64).
......
...@@ -158,33 +158,32 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") ...@@ -158,33 +158,32 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
const std::string shared_mem_name = args[2]; const std::string shared_mem_name = args[2];
const int multigraph = args[3]; const int multigraph = args[3];
const std::string edge_dir = args[4]; const std::string edge_dir = args[4];
CSRPtr csr;
IdArray edge_ids = IdArray::Empty({indices->shape[0]}, IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data); int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++) for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i; edge_data[i] = i;
ImmutableGraph *g = nullptr;
if (shared_mem_name.empty()) { if (shared_mem_name.empty()) {
if (multigraph == kBoolUnknown) { if (multigraph == kBoolUnknown) {
csr.reset(new CSR(indptr, indices, edge_ids)); g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
edge_dir));
} else { } else {
csr.reset(new CSR(indptr, indices, edge_ids, multigraph)); g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir));
} }
} else { } else {
if (multigraph == kBoolUnknown) { if (multigraph == kBoolUnknown) {
csr.reset(new CSR(indptr, indices, edge_ids, shared_mem_name)); g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
edge_dir, shared_mem_name));
} else { } else {
csr.reset(new CSR(indptr, indices, edge_ids, multigraph, shared_mem_name)); g = new ImmutableGraph(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids,
multigraph, edge_dir,
shared_mem_name));
} }
} }
*rv = g;
GraphHandle ghandle;
if (edge_dir == "in")
ghandle = new ImmutableGraph(csr, nullptr);
else
ghandle = new ImmutableGraph(nullptr, csr);
*rv = ghandle;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
...@@ -195,12 +194,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap") ...@@ -195,12 +194,8 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
const bool multigraph = args[3]; const bool multigraph = args[3];
const std::string edge_dir = args[4]; const std::string edge_dir = args[4];
// TODO(minjie): how to know multigraph // TODO(minjie): how to know multigraph
CSRPtr csr(new CSR(shared_mem_name, num_vertices, num_edges, multigraph)); GraphHandle ghandle = new ImmutableGraph(ImmutableGraph::CreateFromCSR(
GraphHandle ghandle; shared_mem_name, num_vertices, num_edges, multigraph, edge_dir));
if (edge_dir == "in")
ghandle = new ImmutableGraph(csr, nullptr);
else
ghandle = new ImmutableGraph(nullptr, csr);
*rv = ghandle; *rv = ghandle;
}); });
...@@ -546,6 +541,18 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") ...@@ -546,6 +541,18 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
*rv = newhandle; *rv = newhandle;
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0];
std::string edge_dir = args[1];
std::string name = args[2];
const GraphInterface *ptr = static_cast<GraphInterface *>(ghandle);
const ImmutableGraph *ig = dynamic_cast<const ImmutableGraph*>(ptr);
CHECK(ig) << "Invalid argument: must be an immutable graph object.";
GraphHandle newhandle = new ImmutableGraph(ig->CopyToSharedMem(edge_dir, name));
*rv = newhandle;
});
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphNumBits")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphHandle ghandle = args[0]; GraphHandle ghandle = args[0];
......
...@@ -68,12 +68,12 @@ struct PairHash { ...@@ -68,12 +68,12 @@ struct PairHash {
}; };
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges) { const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
#ifndef _WIN32 #ifndef _WIN32
const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t); const int64_t file_size = (num_verts + 1 + num_edges * 2) * sizeof(dgl_id_t);
IdArray sm_array = IdArray::EmptyShared( IdArray sm_array = IdArray::EmptyShared(
shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, true); shared_mem_name, {file_size}, DLDataType{kDLInt, 8, 1}, DLContext{kDLCPU, 0}, is_create);
// Create views from the shared memory array. Note that we don't need to save // Create views from the shared memory array. Note that we don't need to save
// the sm_array because the refcount is maintained by the view arrays. // the sm_array because the refcount is maintained by the view arrays.
IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1}); IdArray indptr = sm_array.CreateView({num_verts + 1}, DLDataType{kDLInt, 64, 1});
...@@ -111,7 +111,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids) ...@@ -111,7 +111,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids)
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: indptr_(indptr), indices_(indices), edge_ids_(edge_ids), is_multigraph_(is_multigraph) { : indptr_(indptr), indices_(indices), edge_ids_(edge_ids),
is_multigraph_(is_multigraph) {
CHECK(IsValidIdArray(indptr)); CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(IsValidIdArray(edge_ids));
...@@ -119,7 +120,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) ...@@ -119,7 +120,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &shared_mem_name) { const std::string &shared_mem_name): shared_mem_name_(shared_mem_name) {
CHECK(IsValidIdArray(indptr)); CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(IsValidIdArray(edge_ids));
...@@ -127,7 +128,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -127,7 +128,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
const int64_t num_verts = indptr->shape[0] - 1; const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory( std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges); shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays // copy the given data into the shared memory arrays
indptr_.CopyFrom(indptr); indptr_.CopyFrom(indptr);
indices_.CopyFrom(indices); indices_.CopyFrom(indices);
...@@ -135,7 +136,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, ...@@ -135,7 +136,8 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids,
} }
CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const std::string &shared_mem_name): is_multigraph_(is_multigraph) { const std::string &shared_mem_name): is_multigraph_(is_multigraph),
shared_mem_name_(shared_mem_name) {
CHECK(IsValidIdArray(indptr)); CHECK(IsValidIdArray(indptr));
CHECK(IsValidIdArray(indices)); CHECK(IsValidIdArray(indices));
CHECK(IsValidIdArray(edge_ids)); CHECK(IsValidIdArray(edge_ids));
...@@ -143,7 +145,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, ...@@ -143,7 +145,7 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
const int64_t num_verts = indptr->shape[0] - 1; const int64_t num_verts = indptr->shape[0] - 1;
const int64_t num_edges = indices->shape[0]; const int64_t num_edges = indices->shape[0];
std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory( std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges); shared_mem_name, num_verts, num_edges, true);
// copy the given data into the shared memory arrays // copy the given data into the shared memory arrays
indptr_.CopyFrom(indptr); indptr_.CopyFrom(indptr);
indices_.CopyFrom(indices); indices_.CopyFrom(indices);
...@@ -152,9 +154,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph, ...@@ -152,9 +154,9 @@ CSR::CSR(IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph,
CSR::CSR(const std::string &shared_mem_name, CSR::CSR(const std::string &shared_mem_name,
int64_t num_verts, int64_t num_edges, bool is_multigraph) int64_t num_verts, int64_t num_edges, bool is_multigraph)
: is_multigraph_(is_multigraph) { : is_multigraph_(is_multigraph), shared_mem_name_(shared_mem_name) {
std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory( std::tie(indptr_, indices_, edge_ids_) = MapFromSharedMemory(
shared_mem_name, num_verts, num_edges); shared_mem_name, num_verts, num_edges, false);
} }
bool CSR::IsMultigraph() const { bool CSR::IsMultigraph() const {
...@@ -469,6 +471,15 @@ CSR CSR::CopyTo(const DLContext& ctx) const { ...@@ -469,6 +471,15 @@ CSR CSR::CopyTo(const DLContext& ctx) const {
} }
} }
CSR CSR::CopyToSharedMem(const std::string &name) const {
if (IsSharedMem()) {
CHECK(name == shared_mem_name_);
return *this;
} else {
return CSR(indptr_, indices_, edge_ids_, name);
}
}
CSR CSR::AsNumBits(uint8_t bits) const { CSR CSR::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) { if (NumBits() == bits) {
return *this; return *this;
...@@ -664,6 +675,10 @@ COO COO::CopyTo(const DLContext& ctx) const { ...@@ -664,6 +675,10 @@ COO COO::CopyTo(const DLContext& ctx) const {
} }
} }
COO COO::CopyToSharedMem(const std::string &name) const {
LOG(FATAL) << "COO doesn't supprt shared memory yet";
}
COO COO::AsNumBits(uint8_t bits) const { COO COO::AsNumBits(uint8_t bits) const {
if (NumBits() == bits) { if (NumBits() == bits) {
return *this; return *this;
...@@ -761,7 +776,18 @@ ImmutableGraph ImmutableGraph::CopyTo(const DLContext& ctx) const { ...@@ -761,7 +776,18 @@ ImmutableGraph ImmutableGraph::CopyTo(const DLContext& ctx) const {
// be fixed later. // be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->CopyTo(ctx))); CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->CopyTo(ctx)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyTo(ctx))); CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyTo(ctx)));
return ImmutableGraph(new_incsr, new_outcsr, nullptr); return ImmutableGraph(new_incsr, new_outcsr);
}
ImmutableGraph ImmutableGraph::CopyToSharedMem(const std::string &edge_dir,
const std::string &name) const {
CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, edge_dir);
if (edge_dir == "in")
new_incsr = CSRPtr(new CSR(GetInCSR()->CopyToSharedMem(shared_mem_name)));
else if (edge_dir == "out")
new_outcsr = CSRPtr(new CSR(GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraph(new_incsr, new_outcsr, name);
} }
ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const { ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const {
...@@ -774,7 +800,7 @@ ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const { ...@@ -774,7 +800,7 @@ ImmutableGraph ImmutableGraph::AsNumBits(uint8_t bits) const {
// be fixed later. // be fixed later.
CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->AsNumBits(bits))); CSRPtr new_incsr = CSRPtr(new CSR(GetInCSR()->AsNumBits(bits)));
CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->AsNumBits(bits))); CSRPtr new_outcsr = CSRPtr(new CSR(GetOutCSR()->AsNumBits(bits)));
return ImmutableGraph(new_incsr, new_outcsr, nullptr); return ImmutableGraph(new_incsr, new_outcsr);
} }
} }
......
...@@ -220,7 +220,49 @@ def test_sync_barrier(): ...@@ -220,7 +220,49 @@ def test_sync_barrier():
for worker_id in return_dict.keys(): for worker_id in return_dict.keys():
assert return_dict[worker_id] == 0, "worker %d fails" % worker_id assert return_dict[worker_id] == 0, "worker %d fails" % worker_id
def create_mem(gidx):
gidx1 = gidx.copyto_shared_mem("in", "test_graph5")
gidx2 = gidx.copyto_shared_mem("out", "test_graph6")
time.sleep(30)
def check_mem(gidx):
time.sleep(10)
gidx1 = dgl.graph_index.from_shared_mem_csr_matrix("test_graph5", gidx.number_of_nodes(),
gidx.number_of_edges(), "in", False)
gidx2 = dgl.graph_index.from_shared_mem_csr_matrix("test_graph6", gidx.number_of_nodes(),
gidx.number_of_edges(), "out", False)
in_csr = gidx.adjacency_matrix_scipy(False, "csr")
out_csr = gidx.adjacency_matrix_scipy(True, "csr")
in_csr1 = gidx1.adjacency_matrix_scipy(False, "csr")
assert_array_equal(in_csr.indptr, in_csr1.indptr)
assert_array_equal(in_csr.indices, in_csr1.indices)
out_csr1 = gidx1.adjacency_matrix_scipy(True, "csr")
assert_array_equal(out_csr.indptr, out_csr1.indptr)
assert_array_equal(out_csr.indices, out_csr1.indices)
in_csr2 = gidx2.adjacency_matrix_scipy(False, "csr")
assert_array_equal(in_csr.indptr, in_csr2.indptr)
assert_array_equal(in_csr.indices, in_csr2.indices)
out_csr2 = gidx2.adjacency_matrix_scipy(True, "csr")
assert_array_equal(out_csr.indptr, out_csr2.indptr)
assert_array_equal(out_csr.indices, out_csr2.indices)
gidx1 = gidx1.copyto_shared_mem("in", "test_graph5")
gidx2 = gidx2.copyto_shared_mem("out", "test_graph6")
def test_copy_shared_mem():
csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
gidx = dgl.graph_index.create_graph_index(csr, False, True)
p1 = Process(target=create_mem, args=(gidx,))
p2 = Process(target=check_mem, args=(gidx,))
p1.start()
p2.start()
p1.join()
p2.join()
if __name__ == '__main__': if __name__ == '__main__':
test_copy_shared_mem()
test_init() test_init()
test_sync_barrier() test_sync_barrier()
test_compute() test_compute()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment