"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f9c0217d271ddef788dca9a491b03f5dbfd86e41"
Unverified commit 8f378d90, authored by Da Zheng and committed by GitHub
Browse files

[BUGFIX] fix bugs for running GCN on giant graphs. (#561)

* load mxnet csr.

* enable load large csr.

* fix

* fix.

* fix int overflow.

* fix test.
parent 14af8402
import os
import argparse, time, math import argparse, time, math
import numpy as np import numpy as np
from scipy import sparse as spsp from scipy import sparse as spsp
...@@ -7,14 +8,14 @@ from dgl import DGLGraph ...@@ -7,14 +8,14 @@ from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
class GraphData: class GraphData:
def __init__(self, csr, num_feats): def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0] num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0] num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64) edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64)
csr = spsp.csr_matrix((edge_ids, csr.indices.asnumpy(), csr.indptr.asnumpy()),
shape=csr.shape, dtype=np.int64)
self.graph = dgl.graph_index.GraphIndex(multigraph=False, readonly=True) self.graph = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
self.graph.from_csr_matrix(csr.indptr, csr.indices, "in") self.graph.from_csr_matrix(dgl.utils.toindex(csr.indptr),
dgl.utils.toindex(csr.indices), "in",
dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats)) self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10 self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels, self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
...@@ -31,9 +32,9 @@ def main(args): ...@@ -31,9 +32,9 @@ def main(args):
if args.graph_file != '': if args.graph_file != '':
csr = mx.nd.load(args.graph_file)[0] csr = mx.nd.load(args.graph_file)[0]
n_edges = csr.shape[0] n_edges = csr.shape[0]
data = GraphData(csr, args.num_feats) graph_name = os.path.basename(args.graph_file)
data = GraphData(csr, args.num_feats, graph_name)
csr = None csr = None
graph_name = args.graph_file
else: else:
data = load_data(args) data = load_data(args)
n_edges = data.graph.number_of_edges() n_edges = data.graph.number_of_edges()
...@@ -67,6 +68,7 @@ def main(args): ...@@ -67,6 +68,7 @@ def main(args):
n_test_samples)) n_test_samples))
# create GCN model # create GCN model
print('graph name: ' + graph_name)
g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem", g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem",
args.num_workers, False) args.num_workers, False)
g.ndata['features'] = features g.ndata['features'] = features
......
...@@ -305,9 +305,13 @@ class SharedMemoryStoreServer(object): ...@@ -305,9 +305,13 @@ class SharedMemoryStoreServer(object):
""" """
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port): def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
self.server = None self.server = None
graph_idx = GraphIndex(multigraph=multigraph, readonly=True) if isinstance(graph_data, GraphIndex):
indptr, indices = _to_csr(graph_data, edge_dir, multigraph) graph_idx = graph_data
graph_idx.from_csr_matrix(indptr, indices, edge_dir, _get_graph_path(graph_name)) else:
graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
graph_idx.from_csr_matrix(utils.toindex(indptr), utils.toindex(indices),
edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_idx, multigraph=multigraph, readonly=True) self._graph = DGLGraph(graph_idx, multigraph=multigraph, readonly=True)
self._num_workers = num_workers self._num_workers = num_workers
...@@ -331,7 +335,9 @@ class SharedMemoryStoreServer(object): ...@@ -331,7 +335,9 @@ class SharedMemoryStoreServer(object):
# RPC command: get the graph information from the graph store server. # RPC command: get the graph information from the graph store server.
def get_graph_info(graph_name): def get_graph_info(graph_name):
assert graph_name == self._graph_name assert graph_name == self._graph_name
return self._graph.number_of_nodes(), self._graph.number_of_edges(), \ # if the integers are larger than 2^31, xmlrpc can't handle them.
# we convert them to strings to send them to clients.
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \
self._graph.is_multigraph, edge_dir self._graph.is_multigraph, edge_dir
# RPC command: initialize node embedding in the server. # RPC command: initialize node embedding in the server.
...@@ -532,6 +538,7 @@ class SharedMemoryDGLGraph(BaseGraphStore): ...@@ -532,6 +538,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
if self._worker_id < 0: if self._worker_id < 0:
raise Exception('fail to get graph ' + graph_name + ' from the graph store') raise Exception('fail to get graph ' + graph_name + ' from the graph store')
num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name) num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges)
graph_idx = GraphIndex(multigraph=multigraph, readonly=True) graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir) graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir)
......
...@@ -812,9 +812,9 @@ class GraphIndex(object): ...@@ -812,9 +812,9 @@ class GraphIndex(object):
Parameters Parameters
---------- ----------
indptr : a 1D tensor indptr : utils.Index
index pointer in the CSR format index pointer in the CSR format
indices : a 1D tensor indices : utils.Index
column index array in the CSR format column index array in the CSR format
edge_dir : string edge_dir : string
the edge direction. The supported option is "in" and "out". the edge direction. The supported option is "in" and "out".
...@@ -822,13 +822,9 @@ class GraphIndex(object): ...@@ -822,13 +822,9 @@ class GraphIndex(object):
the name of shared memory the name of shared memory
""" """
assert self.is_readonly() assert self.is_readonly()
indptr = utils.toindex(indptr)
indices = utils.toindex(indices)
edge_ids = utils.toindex(F.arange(0, len(indices)))
self._handle = _CAPI_DGLGraphCSRCreate( self._handle = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(), indptr.todgltensor(),
indices.todgltensor(), indices.todgltensor(),
edge_ids.todgltensor(),
shared_mem_name, shared_mem_name,
self._multigraph, self._multigraph,
edge_dir) edge_dir)
......
...@@ -152,11 +152,16 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate") ...@@ -152,11 +152,16 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0]; const IdArray indptr = args[0];
const IdArray indices = args[1]; const IdArray indices = args[1];
const IdArray edge_ids = args[2]; const std::string shared_mem_name = args[2];
const std::string shared_mem_name = args[3]; const bool multigraph = static_cast<bool>(args[3]);
const bool multigraph = static_cast<bool>(args[4]); const std::string edge_dir = args[4];
const std::string edge_dir = args[5];
CSRPtr csr; CSRPtr csr;
IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
if (shared_mem_name.empty()) if (shared_mem_name.empty())
// TODO(minjie): The array copy here is unnecessary and adds extra overhead. // TODO(minjie): The array copy here is unnecessary and adds extra overhead.
// However, with MXNet backend, the memory would be corrupted if we directly // However, with MXNet backend, the memory would be corrupted if we directly
......
...@@ -161,7 +161,7 @@ def test_load_csr(): ...@@ -161,7 +161,7 @@ def test_load_csr():
# Load CSR normally. # Load CSR normally.
idx = dgl.graph_index.GraphIndex(multigraph=False, readonly=True) idx = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
idx.from_csr_matrix(csr.indptr, csr.indices, 'out') idx.from_csr_matrix(utils.toindex(csr.indptr), utils.toindex(csr.indices), 'out')
assert idx.number_of_nodes() == n assert idx.number_of_nodes() == n
assert idx.number_of_edges() == csr.nnz assert idx.number_of_edges() == csr.nnz
src, dst, eid = idx.edges() src, dst, eid = idx.edges()
...@@ -174,7 +174,8 @@ def test_load_csr(): ...@@ -174,7 +174,8 @@ def test_load_csr():
# Shared memory isn't supported in Windows. # Shared memory isn't supported in Windows.
if os.name is not 'nt': if os.name is not 'nt':
idx = dgl.graph_index.GraphIndex(multigraph=False, readonly=True) idx = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
idx.from_csr_matrix(csr.indptr, csr.indices, 'out', '/test_graph_struct') idx.from_csr_matrix(utils.toindex(csr.indptr), utils.toindex(csr.indices),
'out', '/test_graph_struct')
assert idx.number_of_nodes() == n assert idx.number_of_nodes() == n
assert idx.number_of_edges() == csr.nnz assert idx.number_of_edges() == csr.nnz
src, dst, eid = idx.edges() src, dst, eid = idx.edges()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment