Unverified commit 8f378d90, authored by Da Zheng and committed by GitHub
Browse files

[BUGFIX] fix bugs for running GCN on giant graphs. (#561)

* load mxnet csr.

* enable load large csr.

* fix

* fix.

* fix int overflow.

* fix test.
parent 14af8402
import os
import argparse, time, math
import numpy as np
from scipy import sparse as spsp
......@@ -7,14 +8,14 @@ from dgl import DGLGraph
from dgl.data import register_data_args, load_data
class GraphData:
def __init__(self, csr, num_feats):
def __init__(self, csr, num_feats, graph_name):
num_nodes = csr.shape[0]
num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64)
csr = spsp.csr_matrix((edge_ids, csr.indices.asnumpy(), csr.indptr.asnumpy()),
shape=csr.shape, dtype=np.int64)
self.graph = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
self.graph.from_csr_matrix(csr.indptr, csr.indices, "in")
self.graph.from_csr_matrix(dgl.utils.toindex(csr.indptr),
dgl.utils.toindex(csr.indices), "in",
dgl.contrib.graph_store._get_graph_path(graph_name))
self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
self.num_labels = 10
self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
......@@ -31,9 +32,9 @@ def main(args):
if args.graph_file != '':
csr = mx.nd.load(args.graph_file)[0]
n_edges = csr.shape[0]
data = GraphData(csr, args.num_feats)
graph_name = os.path.basename(args.graph_file)
data = GraphData(csr, args.num_feats, graph_name)
csr = None
graph_name = args.graph_file
else:
data = load_data(args)
n_edges = data.graph.number_of_edges()
......@@ -67,6 +68,7 @@ def main(args):
n_test_samples))
# create GCN model
print('graph name: ' + graph_name)
g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem",
args.num_workers, False)
g.ndata['features'] = features
......
......@@ -305,9 +305,13 @@ class SharedMemoryStoreServer(object):
"""
def __init__(self, graph_data, edge_dir, graph_name, multigraph, num_workers, port):
self.server = None
graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
graph_idx.from_csr_matrix(indptr, indices, edge_dir, _get_graph_path(graph_name))
if isinstance(graph_data, GraphIndex):
graph_idx = graph_data
else:
graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
indptr, indices = _to_csr(graph_data, edge_dir, multigraph)
graph_idx.from_csr_matrix(utils.toindex(indptr), utils.toindex(indices),
edge_dir, _get_graph_path(graph_name))
self._graph = DGLGraph(graph_idx, multigraph=multigraph, readonly=True)
self._num_workers = num_workers
......@@ -331,7 +335,9 @@ class SharedMemoryStoreServer(object):
# RPC command: get the graph information from the graph store server.
def get_graph_info(graph_name):
assert graph_name == self._graph_name
return self._graph.number_of_nodes(), self._graph.number_of_edges(), \
# XML-RPC integers are limited to signed 32-bit values, so node/edge counts of 2^31 or more cannot be sent as ints.
# We send them as strings instead; the client converts them back with int().
return str(self._graph.number_of_nodes()), str(self._graph.number_of_edges()), \
self._graph.is_multigraph, edge_dir
# RPC command: initialize node embedding in the server.
......@@ -532,6 +538,7 @@ class SharedMemoryDGLGraph(BaseGraphStore):
if self._worker_id < 0:
raise Exception('fail to get graph ' + graph_name + ' from the graph store')
num_nodes, num_edges, multigraph, edge_dir = self.proxy.get_graph_info(graph_name)
num_nodes, num_edges = int(num_nodes), int(num_edges)
graph_idx = GraphIndex(multigraph=multigraph, readonly=True)
graph_idx.from_shared_mem_csr_matrix(_get_graph_path(graph_name), num_nodes, num_edges, edge_dir)
......
......@@ -812,9 +812,9 @@ class GraphIndex(object):
Parameters
----------
indptr : a 1D tensor
indptr : utils.Index
index pointer in the CSR format
indices : a 1D tensor
indices : utils.Index
column index array in the CSR format
edge_dir : string
the edge direction. The supported options are "in" and "out".
......@@ -822,13 +822,9 @@ class GraphIndex(object):
the name of shared memory
"""
assert self.is_readonly()
indptr = utils.toindex(indptr)
indices = utils.toindex(indices)
edge_ids = utils.toindex(F.arange(0, len(indices)))
self._handle = _CAPI_DGLGraphCSRCreate(
indptr.todgltensor(),
indices.todgltensor(),
edge_ids.todgltensor(),
shared_mem_name,
self._multigraph,
edge_dir)
......
......@@ -152,11 +152,16 @@ DGL_REGISTER_GLOBAL("graph_index._CAPI_DGLGraphCSRCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const IdArray indptr = args[0];
const IdArray indices = args[1];
const IdArray edge_ids = args[2];
const std::string shared_mem_name = args[3];
const bool multigraph = static_cast<bool>(args[4]);
const std::string edge_dir = args[5];
const std::string shared_mem_name = args[2];
const bool multigraph = static_cast<bool>(args[3]);
const std::string edge_dir = args[4];
CSRPtr csr;
IdArray edge_ids = IdArray::Empty({indices->shape[0]},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
int64_t *edge_data = static_cast<int64_t *>(edge_ids->data);
for (size_t i = 0; i < edge_ids->shape[0]; i++)
edge_data[i] = i;
if (shared_mem_name.empty())
// TODO(minjie): The array copy here is unnecessary and adds extra overhead.
// However, with MXNet backend, the memory would be corrupted if we directly
......
......@@ -161,7 +161,7 @@ def test_load_csr():
# Load CSR normally.
idx = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
idx.from_csr_matrix(csr.indptr, csr.indices, 'out')
idx.from_csr_matrix(utils.toindex(csr.indptr), utils.toindex(csr.indices), 'out')
assert idx.number_of_nodes() == n
assert idx.number_of_edges() == csr.nnz
src, dst, eid = idx.edges()
......@@ -174,7 +174,8 @@ def test_load_csr():
# Shared memory isn't supported on Windows.
if os.name is not 'nt':
idx = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
idx.from_csr_matrix(csr.indptr, csr.indices, 'out', '/test_graph_struct')
idx.from_csr_matrix(utils.toindex(csr.indptr), utils.toindex(csr.indices),
'out', '/test_graph_struct')
assert idx.number_of_nodes() == n
assert idx.number_of_edges() == csr.nnz
src, dst, eid = idx.edges()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment