"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8dba1808852e7f5c08f91296006ec254cecdd1b1"
Unverified Commit 7c47d8c9 authored by Da Zheng, committed by GitHub
Browse files

[Feature] Simplify shared memory graph index (#1381)



* simplify shared memory graph index.

* fix.

* remove edge_dir in SharedMemGraphStore.

* avoid creating shared-mem graph store with from_csr.

* simplify from_csr.

* add comments.

* fix lint.

* remove the test.

* fix compilation error.

* fix a bug.

* fix a bug.
Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-150.us-west-2.compute.internal>
parent 1b152bf5
...@@ -11,8 +11,8 @@ class GraphData: ...@@ -11,8 +11,8 @@ class GraphData:
def __init__(self, csr, num_feats, graph_name): def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0] num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0] num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
self.graph = dgl.graph_index.from_csr(csr.indptr, csr.indices, False, self.graph = dgl.graph_index.from_csr(csr.indptr, csr.indices, False, 'in')
'in', dgl.contrib.graph_store._get_graph_path(graph_name)) self.graph = self.graph.copyto_shared_mem(dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats)) self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10 self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels, self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
......
...@@ -886,13 +886,7 @@ class ImmutableGraph: public GraphInterface { ...@@ -886,13 +886,7 @@ class ImmutableGraph: public GraphInterface {
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(
IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir); IdArray indptr, IdArray indices, IdArray edge_ids, const std::string &edge_dir);
static ImmutableGraphPtr CreateFromCSR( static ImmutableGraphPtr CreateFromCSR(const std::string &shared_mem_name);
IdArray indptr, IdArray indices, IdArray edge_ids,
const std::string &edge_dir, const std::string &shared_mem_name);
static ImmutableGraphPtr CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, const std::string &edge_dir);
/*! \brief Create an immutable graph from COO. */ /*! \brief Create an immutable graph from COO. */
static ImmutableGraphPtr CreateFromCOO( static ImmutableGraphPtr CreateFromCOO(
...@@ -918,12 +912,10 @@ class ImmutableGraph: public GraphInterface { ...@@ -918,12 +912,10 @@ class ImmutableGraph: public GraphInterface {
/*! /*!
* \brief Copy data to shared memory. * \brief Copy data to shared memory.
* \param edge_dir the graph of the specific edge direction to be copied.
* \param name The name of the shared memory. * \param name The name of the shared memory.
* \return The graph in the shared memory * \return The graph in the shared memory
*/ */
static ImmutableGraphPtr CopyToSharedMem( static ImmutableGraphPtr CopyToSharedMem(ImmutableGraphPtr g, const std::string &name);
ImmutableGraphPtr g, const std::string &edge_dir, const std::string &name);
/*! /*!
* \brief Convert the graph to use the given number of bits for storage. * \brief Convert the graph to use the given number of bits for storage.
...@@ -952,6 +944,14 @@ class ImmutableGraph: public GraphInterface { ...@@ -952,6 +944,14 @@ class ImmutableGraph: public GraphInterface {
GetOutCSR()->SortCSR(); GetOutCSR()->SortCSR();
} }
/*! \brief Whether the in-CSR member (in_csr_) is currently set, i.e. non-null. */
bool HasInCSR() const {
  // Use nullptr rather than the NULL macro for the null-pointer comparison.
  return in_csr_ != nullptr;
}
/*! \brief Whether the out-CSR member (out_csr_) is currently set, i.e. non-null. */
bool HasOutCSR() const {
  // Use nullptr rather than the NULL macro for the null-pointer comparison.
  return out_csr_ != nullptr;
}
/*! \brief Cast this graph to a heterograph */ /*! \brief Cast this graph to a heterograph */
HeteroGraphPtr AsHeteroGraph() const; HeteroGraphPtr AsHeteroGraph() const;
...@@ -995,6 +995,8 @@ class ImmutableGraph: public GraphInterface { ...@@ -995,6 +995,8 @@ class ImmutableGraph: public GraphInterface {
// The name of shared memory for this graph. // The name of shared memory for this graph.
// If it's empty, the graph isn't stored in shared memory. // If it's empty, the graph isn't stored in shared memory.
std::string shared_mem_name_; std::string shared_mem_name_;
// We serialize the metadata of the graph index here for shared memory.
NDArray serialized_shared_meta_;
}; };
// inline implementations // inline implementations
......
...@@ -13,7 +13,7 @@ from ..base import ALL, is_all, DGLError, dgl_warning ...@@ -13,7 +13,7 @@ from ..base import ALL, is_all, DGLError, dgl_warning
from .. import backend as F from .. import backend as F
from ..graph import DGLGraph from ..graph import DGLGraph
from .. import utils from .. import utils
from ..graph_index import GraphIndex, create_graph_index, from_csr, from_shared_mem_csr_matrix from ..graph_index import GraphIndex, create_graph_index, from_shared_mem_graph_index
from .._ffi.ndarray import empty_shared_mem from .._ffi.ndarray import empty_shared_mem
from .._ffi.function import _init_api from .._ffi.function import _init_api
from .. import ndarray as nd from .. import ndarray as nd
...@@ -25,9 +25,6 @@ def _get_ndata_path(graph_name, ndata_name): ...@@ -25,9 +25,6 @@ def _get_ndata_path(graph_name, ndata_name):
def _get_edata_path(graph_name, edata_name): def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name return "/" + graph_name + "_edge_" + edata_name
def _get_edata_path(graph_name, edata_name):
return "/" + graph_name + "_edge_" + edata_name
def _get_graph_path(graph_name): def _get_graph_path(graph_name):
return "/" + graph_name return "/" + graph_name
...@@ -118,21 +115,6 @@ class EdgeDataView(MutableMapping): ...@@ -118,21 +115,6 @@ class EdgeDataView(MutableMapping):
data = self._graph.get_e_repr(self._edges) data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame}) return repr({key : data[key] for key in self._graph._edge_frame})
def _to_csr(graph_data, edge_dir, multigraph=None):
try:
indptr = graph_data.indptr
indices = graph_data.indices
return indptr, indices
except:
if isinstance(graph_data, scipy.sparse.spmatrix):
csr = graph_data.tocsr()
return csr.indptr, csr.indices
else:
idx = create_graph_index(graph_data=graph_data, readonly=True)
transpose = (edge_dir != 'in')
csr = idx.adjacency_matrix_scipy(transpose, 'csr')
return csr.indptr, csr.indices
class Barrier(object): class Barrier(object):
""" A barrier in the KVStore server used for one synchronization. """ A barrier in the KVStore server used for one synchronization.
...@@ -305,8 +287,6 @@ class SharedMemoryStoreServer(object): ...@@ -305,8 +287,6 @@ class SharedMemoryStoreServer(object):
---------- ----------
graph_data : graph data graph_data : graph data
Data to initialize graph. Data to initialize graph.
edge_dir : string
the edge direction for the graph structure ("in" or "out")
graph_name : string graph_name : string
Define the name of the graph, so the client can use the name to access the graph. Define the name of the graph, so the client can use the name to access the graph.
multigraph : bool, optional multigraph : bool, optional
...@@ -317,27 +297,23 @@ class SharedMemoryStoreServer(object): ...@@ -317,27 +297,23 @@ class SharedMemoryStoreServer(object):
port : int port : int
The port that the server listens to. The port that the server listens to.
""" """
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port): def __init__(self, graph_data, graph_name, multigraph, num_workers, port):
self.server = None self.server = None
if multigraph is not None: if multigraph is not None:
dgl_warning("multigraph will be deprecated." \ dgl_warning("multigraph will be deprecated." \
"DGL will treat all graphs as multigraph in the future.") "DGL will treat all graphs as multigraph in the future.")
if isinstance(graph_data, GraphIndex): if isinstance(graph_data, GraphIndex):
graph_data = graph_data.copyto_shared_mem(edge_dir, _get_graph_path(graph_name)) graph_data = graph_data.copyto_shared_mem(_get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, readonly=True)
elif isinstance(graph_data, DGLGraph): elif isinstance(graph_data, DGLGraph):
graph_data = graph_data._graph.copyto_shared_mem(edge_dir, _get_graph_path(graph_name)) graph_data = graph_data._graph.copyto_shared_mem(_get_graph_path(graph_name))
self._graph = DGLGraph(graph_data, readonly=True)
else: else:
indptr, indices = _to_csr(graph_data, edge_dir) graph_data = create_graph_index(graph_data, readonly=True)
graph_idx = from_csr(utils.toindex(indptr), utils.toindex(indices), graph_data = graph_data.copyto_shared_mem(_get_graph_path(graph_name))
edge_dir, _get_graph_path(graph_name)) self._graph = DGLGraph(graph_data, readonly=True)
self._graph = DGLGraph(graph_idx, readonly=True)
self._num_workers = num_workers self._num_workers = num_workers
self._graph_name = graph_name self._graph_name = graph_name
self._edge_dir = edge_dir
self._registered_nworkers = 0 self._registered_nworkers = 0
self._barrier = BarrierManager(num_workers) self._barrier = BarrierManager(num_workers)
...@@ -358,8 +334,7 @@ class SharedMemoryStoreServer(object): ...@@ -358,8 +334,7 @@ class SharedMemoryStoreServer(object):
assert graph_name == self._graph_name assert graph_name == self._graph_name
# if the integers are larger than 2^31, xmlrpc can't handle them. # if the integers are larger than 2^31, xmlrpc can't handle them.
# we convert them to strings to send them to clients. # we convert them to strings to send them to clients.
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \ return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges())
True, edge_dir
# RPC command: initialize node embedding in the server. # RPC command: initialize node embedding in the server.
def init_ndata(init, ndata_name, shape, dtype): def init_ndata(init, ndata_name, shape, dtype):
...@@ -560,11 +535,10 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -560,11 +535,10 @@ class SharedMemoryDGLGraph(BaseGraphStore):
self._worker_id, self._num_workers = self.proxy.register(graph_name) self._worker_id, self._num_workers = self.proxy.register(graph_name)
if self._worker_id < 0: if self._worker_id < 0:
raise Exception('fail to get graph ' + graph_name + ' from the graph store') raise Exception('fail to get graph ' + graph_name + ' from the graph store')
num_nodes, num_edges, _, edge_dir = self.proxy.get_graph_info(graph_name) num_nodes, num_edges = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges) num_nodes, num_edges = int(num_nodes), int(num_edges)
graph_idx = from_shared_mem_csr_matrix(_get_graph_path(graph_name), graph_idx = from_shared_mem_graph_index(_get_graph_path(graph_name))
num_nodes, num_edges, edge_dir)
super(SharedMemoryDGLGraph, self).__init__(graph_idx) super(SharedMemoryDGLGraph, self).__init__(graph_idx)
self._init_manager = InitializerManager() self._init_manager = InitializerManager()
...@@ -1059,7 +1033,7 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -1059,7 +1033,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
def create_graph_store_server(graph_data, graph_name, store_type, num_workers, def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
multigraph=None, edge_dir='in', port=8000): multigraph=None, port=8000):
"""Create the graph store server. """Create the graph store server.
The server loads graph structure and node embeddings and edge embeddings. The server loads graph structure and node embeddings and edge embeddings.
...@@ -1092,9 +1066,6 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers, ...@@ -1092,9 +1066,6 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
multigraph : bool, optional multigraph : bool, optional
Deprecated (Will be deleted in the future). Deprecated (Will be deleted in the future).
Whether the graph would be a multigraph (default: True) Whether the graph would be a multigraph (default: True)
edge_dir : string
the edge direction for the graph structure. The supported option is
"in" and "out".
port : int port : int
The port that the server listens to. The port that the server listens to.
...@@ -1106,7 +1077,7 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers, ...@@ -1106,7 +1077,7 @@ def create_graph_store_server(graph_data, graph_name, store_type, num_workers,
if multigraph is not None: if multigraph is not None:
dgl_warning("multigraph is deprecated." \ dgl_warning("multigraph is deprecated." \
"DGL treat all graphs as multigraph by default.") "DGL treat all graphs as multigraph by default.")
return SharedMemoryStoreServer(graph_data, edge_dir, graph_name, None, return SharedMemoryStoreServer(graph_data, graph_name, None,
num_workers, port) num_workers, port)
def create_graph_from_store(graph_name, store_type, port=8000): def create_graph_from_store(graph_name, store_type, port=8000):
......
...@@ -881,15 +881,13 @@ class GraphIndex(ObjectBase): ...@@ -881,15 +881,13 @@ class GraphIndex(ObjectBase):
""" """
return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id) return _CAPI_DGLImmutableGraphCopyTo(self, ctx.device_type, ctx.device_id)
def copyto_shared_mem(self, edge_dir, shared_mem_name): def copyto_shared_mem(self, shared_mem_name):
"""Copy this immutable graph index to shared memory. """Copy this immutable graph index to shared memory.
NOTE: this method only works for immutable graph index NOTE: this method only works for immutable graph index
Parameters Parameters
---------- ----------
edge_dir : string
Indicate which CSR should copy ("in", "out", "both").
shared_mem_name : string shared_mem_name : string
The name of the shared memory. The name of the shared memory.
...@@ -898,7 +896,7 @@ class GraphIndex(ObjectBase): ...@@ -898,7 +896,7 @@ class GraphIndex(ObjectBase):
GraphIndex GraphIndex
The graph index on the given device context. The graph index on the given device context.
""" """
return _CAPI_DGLImmutableGraphCopyToSharedMem(self, edge_dir, shared_mem_name) return _CAPI_DGLImmutableGraphCopyToSharedMem(self, shared_mem_name)
def nbits(self): def nbits(self):
"""Return the number of integer bits used in the storage (32 or 64). """Return the number of integer bits used in the storage (32 or 64).
...@@ -1017,8 +1015,7 @@ def from_coo(num_nodes, src, dst, readonly): ...@@ -1017,8 +1015,7 @@ def from_coo(num_nodes, src, dst, readonly):
gidx.add_edges(src, dst) gidx.add_edges(src, dst)
return gidx return gidx
def from_csr(indptr, indices, def from_csr(indptr, indices, direction):
direction, shared_mem_name=""):
"""Load a graph from CSR arrays. """Load a graph from CSR arrays.
Parameters Parameters
...@@ -1029,38 +1026,24 @@ def from_csr(indptr, indices, ...@@ -1029,38 +1026,24 @@ def from_csr(indptr, indices,
column index array in the CSR format column index array in the CSR format
direction : str direction : str
the edge direction. Either "in" or "out". the edge direction. Either "in" or "out".
shared_mem_name : str
the name of shared memory
""" """
indptr = utils.toindex(indptr) indptr = utils.toindex(indptr)
indices = utils.toindex(indices) indices = utils.toindex(indices)
gidx = _CAPI_DGLGraphCSRCreate( gidx = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(), indptr.todgltensor(),
indices.todgltensor(), indices.todgltensor(),
shared_mem_name,
direction) direction)
return gidx return gidx
def from_shared_mem_csr_matrix(shared_mem_name, def from_shared_mem_graph_index(shared_mem_name):
num_nodes, num_edges, edge_dir): """Load a graph index from the shared memory.
"""Load a graph from the shared memory in the CSR format.
Parameters Parameters
---------- ----------
shared_mem_name : string shared_mem_name : string
the name of shared memory the name of shared memory
num_nodes : int
the number of nodes
num_edges : int
the number of edges
edge_dir : string
the edge direction. The supported option is "in" and "out".
""" """
gidx = _CAPI_DGLGraphCSRCreateMMap( return _CAPI_DGLGraphCSRCreateMMap(shared_mem_name)
shared_mem_name,
int(num_nodes), int(num_edges),
edge_dir)
return gidx
def from_networkx(nx_graph, readonly): def from_networkx(nx_graph, readonly):
"""Convert from networkx graph. """Convert from networkx graph.
......
...@@ -44,30 +44,20 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") ...@@ -44,30 +44,20 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0]; const IdArray indptr = args[0];
const IdArray indices = args[1]; const IdArray indices = args[1];
const std::string shared_mem_name = args[2]; const std::string edge_dir = args[2];
const std::string edge_dir = args[3];
IdArray edge_ids = IdArray::Empty({indices->shape[0]}, IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0}); DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data); int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++) for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i; edge_data[i] = i;
if (shared_mem_name.empty()) { *rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
*rv = GraphRef(ImmutableGraph::CreateFromCSR(indptr, indices, edge_ids, edge_dir));
} else {
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
indptr, indices, edge_ids, edge_dir, shared_mem_name));
}
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreateMMap")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string shared_mem_name = args[0]; const std::string shared_mem_name = args[0];
const int64_t num_vertices = args[1]; *rv = GraphRef(ImmutableGraph::CreateFromCSR(shared_mem_name));
const int64_t num_edges = args[2];
const std::string edge_dir = args[3];
*rv = GraphRef(ImmutableGraph::CreateFromCSR(
shared_mem_name, num_vertices, num_edges, edge_dir));
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphAddVertices")
......
...@@ -27,6 +27,55 @@ inline std::string GetSharedMemName(const std::string &name, const std::string & ...@@ -27,6 +27,55 @@ inline std::string GetSharedMemName(const std::string &name, const std::string &
return name + "_" + edge_dir; return name + "_" + edge_dir;
} }
/*
 * The metadata of a graph index that are needed for shared-memory graph.
 *
 * SerializeMetadata() copies an instance of this struct byte-for-byte into a
 * shared-memory tensor and DeserializeMetadata() copies it back out, so the
 * struct must stay trivially copyable. The default member initializers below
 * do not change its layout; they only guarantee that a default-constructed
 * instance never carries indeterminate values.
 */
struct GraphIndexMetadata {
  // Number of nodes in the graph.
  int64_t num_nodes = 0;
  // Number of edges in the graph.
  int64_t num_edges = 0;
  // Whether an in-CSR was stored for the graph.
  bool has_in_csr = false;
  // Whether an out-CSR was stored for the graph.
  bool has_out_csr = false;
  // Whether a COO was stored for the graph (SerializeMetadata writes false).
  bool has_coo = false;
};
/*
 * Pack the metadata of a graph index into a shared-memory tensor, so that a
 * different process is able to rebuild a GraphIndex from that tensor.
 *
 * \param gidx the graph index whose node/edge counts and CSR presence flags
 *        are recorded.
 * \param name the name handed to NDArray::EmptyShared for the metadata tensor.
 * \return the tensor holding the raw bytes of a GraphIndexMetadata. On Windows
 *         LOG(FATAL) is issued instead.
 */
NDArray SerializeMetadata(ImmutableGraphPtr gidx, const std::string &name) {
#ifdef _WIN32
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
  return NDArray();
#else
  GraphIndexMetadata packed;
  packed.num_nodes = gidx->NumVertices();
  packed.num_edges = gidx->NumEdges();
  packed.has_in_csr = gidx->HasInCSR();
  packed.has_out_csr = gidx->HasOutCSR();
  packed.has_coo = false;

  // The tensor is a flat array of 8-bit integers, one per byte of the struct.
  const DLDataType byte_type{kDLInt, 8, 1};
  const DLContext cpu_ctx{kDLCPU, 0};
  // NOTE(review): the trailing `true` presumably asks EmptyShared to create
  // the shared-memory segment (DeserializeMetadata passes `false`) -- confirm.
  NDArray shared_arr = NDArray::EmptyShared(name, {sizeof(packed)}, byte_type, cpu_ctx, true);
  memcpy(shared_arr->data, &packed, sizeof(packed));
  return shared_arr;
#endif  // _WIN32
}
/*
 * Deserialize the metadata of a graph index.
 *
 * \param name the name handed to NDArray::EmptyShared to locate the metadata
 *        tensor.
 * \return the GraphIndexMetadata whose bytes were copied out of that tensor.
 *         On Windows LOG(FATAL) is issued and the returned value is
 *         zero-initialized.
 */
GraphIndexMetadata DeserializeMetadata(const std::string &name) {
  // Value-initialize so that no path (in particular the Windows branch, which
  // never fills the struct) can return indeterminate fields.
  GraphIndexMetadata meta{};
#ifndef _WIN32
  // NOTE(review): the trailing `false` presumably attaches to an existing
  // shared-memory segment instead of creating one -- confirm.
  NDArray meta_arr = NDArray::EmptyShared(name, {sizeof(meta)}, DLDataType{kDLInt, 8, 1},
                                          DLContext{kDLCPU, 0}, false);
  memcpy(&meta, meta_arr->data, sizeof(meta));
#else
  LOG(FATAL) << "CSR graph doesn't support shared memory in Windows yet";
#endif  // _WIN32
  return meta;
}
std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory( std::tuple<IdArray, IdArray, IdArray> MapFromSharedMemory(
const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) { const std::string &shared_mem_name, int64_t num_verts, int64_t num_edges, bool is_create) {
#ifndef _WIN32 #ifndef _WIN32
...@@ -467,34 +516,16 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ...@@ -467,34 +516,16 @@ ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
} }
} }
ImmutableGraphPtr ImmutableGraph::CreateFromCSR( ImmutableGraphPtr ImmutableGraph::CreateFromCSR(const std::string &name) {
IdArray indptr, IdArray indices, IdArray edge_ids, GraphIndexMetadata meta = DeserializeMetadata(GetSharedMemName(name, "meta"));
const std::string &edge_dir, CSRPtr in_csr, out_csr;
const std::string &shared_mem_name) { if (meta.has_in_csr) {
CSRPtr csr(new CSR(indptr, indices, edge_ids, in_csr = CSRPtr(new CSR(GetSharedMemName(name, "in"), meta.num_nodes, meta.num_edges));
GetSharedMemName(shared_mem_name, edge_dir)));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
} }
} if (meta.has_out_csr) {
out_csr = CSRPtr(new CSR(GetSharedMemName(name, "out"), meta.num_nodes, meta.num_edges));
ImmutableGraphPtr ImmutableGraph::CreateFromCSR(
const std::string &shared_mem_name, size_t num_vertices,
size_t num_edges, const std::string &edge_dir) {
CSRPtr csr(new CSR(GetSharedMemName(shared_mem_name, edge_dir), num_vertices, num_edges));
if (edge_dir == "in") {
return ImmutableGraphPtr(new ImmutableGraph(csr, nullptr, shared_mem_name));
} else if (edge_dir == "out") {
return ImmutableGraphPtr(new ImmutableGraph(nullptr, csr, shared_mem_name));
} else {
LOG(FATAL) << "Unknown edge direction: " << edge_dir;
return ImmutableGraphPtr();
} }
return ImmutableGraphPtr(new ImmutableGraph(in_csr, out_csr, name));
} }
ImmutableGraphPtr ImmutableGraph::CreateFromCOO( ImmutableGraphPtr ImmutableGraph::CreateFromCOO(
...@@ -527,15 +558,17 @@ ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& c ...@@ -527,15 +558,17 @@ ImmutableGraphPtr ImmutableGraph::CopyTo(ImmutableGraphPtr g, const DLContext& c
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr)); return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr));
} }
ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g, ImmutableGraphPtr ImmutableGraph::CopyToSharedMem(ImmutableGraphPtr g, const std::string &name) {
const std::string &edge_dir, const std::string &name) {
CSRPtr new_incsr, new_outcsr; CSRPtr new_incsr, new_outcsr;
std::string shared_mem_name = GetSharedMemName(name, edge_dir); std::string shared_mem_name = GetSharedMemName(name, "in");
if (edge_dir == std::string("in")) new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
new_incsr = CSRPtr(new CSR(g->GetInCSR()->CopyToSharedMem(shared_mem_name)));
else if (edge_dir == std::string("out")) shared_mem_name = GetSharedMemName(name, "out");
new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name))); new_outcsr = CSRPtr(new CSR(g->GetOutCSR()->CopyToSharedMem(shared_mem_name)));
return ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
auto new_g = ImmutableGraphPtr(new ImmutableGraph(new_incsr, new_outcsr, name));
new_g->serialized_shared_meta_ = SerializeMetadata(new_g, GetSharedMemName(name, "meta"));
return new_g;
} }
ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) { ImmutableGraphPtr ImmutableGraph::AsNumBits(ImmutableGraphPtr g, uint8_t bits) {
...@@ -622,10 +655,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo") ...@@ -622,10 +655,9 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyTo")
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphCopyToSharedMem")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
GraphRef g = args[0]; GraphRef g = args[0];
std::string edge_dir = args[1]; std::string name = args[1];
std::string name = args[2];
ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr())); ImmutableGraphPtr ig = CHECK_NOTNULL(std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()));
*rv = ImmutableGraph::CopyToSharedMem(ig, edge_dir, name); *rv = ImmutableGraph::CopyToSharedMem(ig, name);
}); });
DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits") DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLImmutableGraphAsNumBits")
......
...@@ -170,20 +170,6 @@ def test_load_csr(): ...@@ -170,20 +170,6 @@ def test_load_csr():
assert np.all(F.asnumpy(src) == coo.row) assert np.all(F.asnumpy(src) == coo.row)
assert np.all(F.asnumpy(dst) == coo.col) assert np.all(F.asnumpy(dst) == coo.col)
# Load CSR to shared memory.
# Shared memory isn't supported in Windows.
if os.name is not 'nt':
idx = dgl.graph_index.from_csr(
utils.toindex(csr.indptr), utils.toindex(csr.indices),
'out', '/test_graph_struct')
assert idx.number_of_nodes() == n
assert idx.number_of_edges() == csr.nnz
src, dst, eid = idx.edges()
src, dst, eid = src.tousertensor(), dst.tousertensor(), eid.tousertensor()
coo = csr.tocoo()
assert np.all(F.asnumpy(src) == coo.row)
assert np.all(F.asnumpy(dst) == coo.col)
def test_edge_ids(): def test_edge_ids():
np.random.seed(0) np.random.seed(0)
csr = (spsp.random(20, 20, density=0.1, format='csr') != 0).astype(np.int64) csr = (spsp.random(20, 20, density=0.1, format='csr') != 0).astype(np.int64)
......
...@@ -47,6 +47,7 @@ def create_graph_store(graph_name): ...@@ -47,6 +47,7 @@ def create_graph_store(graph_name):
def check_init_func(worker_id, graph_name, return_dict): def check_init_func(worker_id, graph_name, return_dict):
np.random.seed(0) np.random.seed(0)
csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64) csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
tmp_g = dgl.DGLGraph(csr, readonly=True, multigraph=False)
# Verify the graph structure loaded from the shared memory. # Verify the graph structure loaded from the shared memory.
try: try:
...@@ -55,10 +56,10 @@ def check_init_func(worker_id, graph_name, return_dict): ...@@ -55,10 +56,10 @@ def check_init_func(worker_id, graph_name, return_dict):
return_dict[worker_id] = -1 return_dict[worker_id] = -1
return return
src, dst = g.all_edges() src, dst = g.all_edges(order='srcdst')
coo = csr.tocoo() src1, dst1 = tmp_g.all_edges(order='srcdst')
assert_array_equal(F.asnumpy(dst), coo.row) assert_array_equal(F.asnumpy(dst), F.asnumpy(dst1))
assert_array_equal(F.asnumpy(src), coo.col) assert_array_equal(F.asnumpy(src), F.asnumpy(src1))
feat = F.asnumpy(g.nodes[0].data['feat']) feat = F.asnumpy(g.nodes[0].data['feat'])
assert_array_equal(np.squeeze(feat), np.arange(10, dtype=feat.dtype)) assert_array_equal(np.squeeze(feat), np.arange(10, dtype=feat.dtype))
feat = F.asnumpy(g.edges[0].data['feat']) feat = F.asnumpy(g.edges[0].data['feat'])
...@@ -90,7 +91,7 @@ def server_func(num_workers, graph_name, server_init): ...@@ -90,7 +91,7 @@ def server_func(num_workers, graph_name, server_init):
csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64) csr = (spsp.random(num_nodes, num_nodes, density=0.1, format='csr') != 0).astype(np.int64)
g = dgl.contrib.graph_store.create_graph_store_server(csr, graph_name, "shared_mem", num_workers, g = dgl.contrib.graph_store.create_graph_store_server(csr, graph_name, "shared_mem", num_workers,
False, edge_dir="in", port=rand_port) False, port=rand_port)
assert num_nodes == g._graph.number_of_nodes() assert num_nodes == g._graph.number_of_nodes()
assert num_edges == g._graph.number_of_edges() assert num_edges == g._graph.number_of_edges()
nfeat = np.arange(0, num_nodes * 10).astype('float32').reshape((num_nodes, 10)) nfeat = np.arange(0, num_nodes * 10).astype('float32').reshape((num_nodes, 10))
...@@ -237,8 +238,7 @@ def test_sync_barrier(): ...@@ -237,8 +238,7 @@ def test_sync_barrier():
def create_mem(gidx, cond_v, shared_v): def create_mem(gidx, cond_v, shared_v):
# serialize create_mem before check_mem # serialize create_mem before check_mem
cond_v.acquire() cond_v.acquire()
gidx1 = gidx.copyto_shared_mem("in", "test_graph5") gidx1 = gidx.copyto_shared_mem("test_graph5")
gidx2 = gidx.copyto_shared_mem("out", "test_graph6")
shared_v.value = 1; shared_v.value = 1;
cond_v.notify() cond_v.notify()
cond_v.release() cond_v.release()
...@@ -256,10 +256,7 @@ def check_mem(gidx, cond_v, shared_v): ...@@ -256,10 +256,7 @@ def check_mem(gidx, cond_v, shared_v):
cond_v.wait() cond_v.wait()
cond_v.release() cond_v.release()
gidx1 = dgl.graph_index.from_shared_mem_csr_matrix("test_graph5", gidx.number_of_nodes(), gidx1 = dgl.graph_index.from_shared_mem_graph_index("test_graph5")
gidx.number_of_edges(), "in")
gidx2 = dgl.graph_index.from_shared_mem_csr_matrix("test_graph6", gidx.number_of_nodes(),
gidx.number_of_edges(), "out")
in_csr = gidx.adjacency_matrix_scipy(False, "csr") in_csr = gidx.adjacency_matrix_scipy(False, "csr")
out_csr = gidx.adjacency_matrix_scipy(True, "csr") out_csr = gidx.adjacency_matrix_scipy(True, "csr")
...@@ -270,15 +267,7 @@ def check_mem(gidx, cond_v, shared_v): ...@@ -270,15 +267,7 @@ def check_mem(gidx, cond_v, shared_v):
assert_array_equal(out_csr.indptr, out_csr1.indptr) assert_array_equal(out_csr.indptr, out_csr1.indptr)
assert_array_equal(out_csr.indices, out_csr1.indices) assert_array_equal(out_csr.indices, out_csr1.indices)
in_csr2 = gidx2.adjacency_matrix_scipy(False, "csr") gidx1 = gidx1.copyto_shared_mem("test_graph5")
assert_array_equal(in_csr.indptr, in_csr2.indptr)
assert_array_equal(in_csr.indices, in_csr2.indices)
out_csr2 = gidx2.adjacency_matrix_scipy(True, "csr")
assert_array_equal(out_csr.indptr, out_csr2.indptr)
assert_array_equal(out_csr.indices, out_csr2.indices)
gidx1 = gidx1.copyto_shared_mem("in", "test_graph5")
gidx2 = gidx2.copyto_shared_mem("out", "test_graph6")
#sync for exit #sync for exit
cond_v.acquire() cond_v.acquire()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment