Commit eafcb7e7 authored by Minjie Wang's avatar Minjie Wang Committed by Da Zheng
Browse files

[Bugfix][MXNet] Fix edge order and builtin max bug in mx (#247)

* Fix edge order and builtin max bug in mx

* fix as requested
parent 71fa26ac
...@@ -105,6 +105,9 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -105,6 +105,9 @@ def sparse_matrix(data, index, shape, force_format=False):
SparseMatrix SparseMatrix
The framework-specific sparse matrix. It can be stored in any format The framework-specific sparse matrix. It can be stored in any format
unless force_format is True. unless force_format is True.
Tensor
The data conversion index due to the sparse format change.
None if no conversion is needed.
""" """
pass pass
......
...@@ -27,13 +27,24 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -27,13 +27,24 @@ def sparse_matrix(data, index, shape, force_format=False):
raise TypeError('MXNet backend only supports CSR format,' raise TypeError('MXNet backend only supports CSR format,'
' but COO format is forced.') ' but COO format is forced.')
coord = index[1] coord = index[1]
return nd.sparse.csr_matrix((data, (coord[0], coord[1])), # generate convert idx
# FIXME: cannot use int64
tmp_data = nd.arange(len(coord[0]), dtype=data.dtype, ctx=coord[0].context)
tmp_spmat = nd.sparse.csr_matrix((tmp_data, (coord[0], coord[1])),
tuple(shape), ctx=data.context) tuple(shape), ctx=data.context)
convert_idx = nd.cast(tmp_spmat.data, dtype='int64')
# shuffle the data
data = data[convert_idx]
spmat = nd.sparse.csr_matrix((data, tmp_spmat.indices, tmp_spmat.indptr),
tuple(shape), ctx=data.context)
return spmat, convert_idx
elif fmt == 'csr': elif fmt == 'csr':
indices = index[1] indices = index[1]
indptr = index[2] indptr = index[2]
return nd.sparse.csr_matrix((data, indices, indptr), spmat = nd.sparse.csr_matrix((data, indices, indptr),
tuple(shape), ctx=data.context) tuple(shape), ctx=data.context)
# No conversion is required.
return spmat, None
else: else:
raise TypeError('Invalid format: %s.' % fmt) raise TypeError('Invalid format: %s.' % fmt)
...@@ -73,7 +84,7 @@ def mean(input, dim): ...@@ -73,7 +84,7 @@ def mean(input, dim):
return nd.mean(input, axis=dim) return nd.mean(input, axis=dim)
def max(input, dim): def max(input, dim):
return nd.max(input, axis=dim).asnumpy()[0] return nd.max(input, axis=dim)
def cat(seq, dim): def cat(seq, dim):
return nd.concat(*seq, dim=dim) return nd.concat(*seq, dim=dim)
......
...@@ -24,7 +24,9 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -24,7 +24,9 @@ def sparse_matrix(data, index, shape, force_format=False):
if fmt != 'coo': if fmt != 'coo':
raise TypeError('Pytorch backend only supports COO format. But got %s.' % fmt) raise TypeError('Pytorch backend only supports COO format. But got %s.' % fmt)
# NOTE: use _sparse_coo_tensor_unsafe to avoid unnecessary boundary check # NOTE: use _sparse_coo_tensor_unsafe to avoid unnecessary boundary check
return th._sparse_coo_tensor_unsafe(index[1], data, shape) spmat = th._sparse_coo_tensor_unsafe(index[1], data, shape)
# No conversion is required.
return spmat, None
def sparse_matrix_indices(spmat): def sparse_matrix_indices(spmat):
return ('coo', spmat._indices()) return ('coo', spmat._indices())
......
...@@ -53,14 +53,16 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -53,14 +53,16 @@ class SrcMulEdgeMessageFunction(MessageFunction):
return _is_spmv_supported_edge_feat(g, self.edge_field) return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, edges): def __call__(self, edges):
src_data = edges.src[self.src_field] sdata = edges.src[self.src_field]
edata = edges.data[self.edge_field] edata = edges.data[self.edge_field]
if F.ndim(edata) == 1: # Due to the different broadcasting semantics of different backends,
# edge feature is a scalar, unsqueeze dims of len 1 # we need to broadcast the sdata and edata to be of the same rank.
src_dim = F.ndim(src_data) rank = max(F.ndim(sdata), F.ndim(edata))
new_eshape = (F.shape(edata)[0],) + (1,) * (src_dim - 1) sshape = F.shape(sdata)
edata = F.reshape(edata, new_eshape) eshape = F.shape(edata)
ret = self.mul_op(src_data, edata) sdata = F.reshape(sdata, sshape + (1,) * (rank - F.ndim(sdata)))
edata = F.reshape(edata, eshape + (1,) * (rank - F.ndim(edata)))
ret = self.mul_op(sdata, edata)
return {self.out_field : ret} return {self.out_field : ret}
@property @property
......
...@@ -2703,7 +2703,7 @@ class DGLGraph(object): ...@@ -2703,7 +2703,7 @@ class DGLGraph(object):
SparseTensor SparseTensor
The adjacency matrix. The adjacency matrix.
""" """
return self._graph.adjacency_matrix(transpose, ctx) return self._graph.adjacency_matrix(transpose, ctx)[0]
def incidence_matrix(self, type, ctx=F.cpu()): def incidence_matrix(self, type, ctx=F.cpu()):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
...@@ -2745,7 +2745,7 @@ class DGLGraph(object): ...@@ -2745,7 +2745,7 @@ class DGLGraph(object):
SparseTensor SparseTensor
The incidence matrix. The incidence matrix.
""" """
return self._graph.incidence_matrix(type, ctx) return self._graph.incidence_matrix(type, ctx)[0]
def line_graph(self, backtracking=True, shared=False): def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
......
...@@ -484,28 +484,6 @@ class GraphIndex(object): ...@@ -484,28 +484,6 @@ class GraphIndex(object):
induced_nodes = utils.toindex(rst(1)) induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e) return SubgraphIndex(rst(0), self, induced_nodes, e)
def adjacency_matrix_indices_and_shape(self, transpose=False):
"""Return the indices and dense shape of adjacency matrix representation of
this graph.
utils.CtxCachedObject
An object that returns indices tensor given context.
tuple
Dense shape of the adjacency matrix
"""
if not 'adj_ind_shape' in self._cache:
src, dst, _ = self.edges(sorted=False)
src = F.unsqueeze(src.tousertensor(), 0)
dst = F.unsqueeze(dst.tousertensor(), 0)
n = self.number_of_nodes()
if transpose:
idx = F.cat([src, dst], dim=0)
else:
idx = F.cat([dst, src], dim=0)
cached_idx = utils.CtxCachedObject(lambda ctx: F.copy_to(idx, ctx))
self._cache['adj_ind_shape'] = (cached_idx, (n, n))
return self._cache['adj_ind_shape']
def adjacency_matrix(self, transpose, ctx): def adjacency_matrix(self, transpose, ctx):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
...@@ -526,6 +504,9 @@ class GraphIndex(object): ...@@ -526,6 +504,9 @@ class GraphIndex(object):
------- -------
SparseTensor SparseTensor
The adjacency matrix. The adjacency matrix.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError('Expect bool value for "transpose" arg,'
...@@ -543,8 +524,9 @@ class GraphIndex(object): ...@@ -543,8 +524,9 @@ class GraphIndex(object):
m = self.number_of_edges() m = self.number_of_edges()
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
adj = F.sparse_matrix(dat, ('coo', idx), (n, n)) adj, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, n))
return adj shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return adj, shuffle_idx
def incidence_matrix(self, type, ctx): def incidence_matrix(self, type, ctx):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
...@@ -577,6 +559,9 @@ class GraphIndex(object): ...@@ -577,6 +559,9 @@ class GraphIndex(object):
------- -------
SparseTensor SparseTensor
The incidence matrix. The incidence matrix.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
src, dst, eid = self.edges(sorted=False) src, dst, eid = self.edges(sorted=False)
src = src.tousertensor(ctx) # the index of the ctx will be cached src = src.tousertensor(ctx) # the index of the ctx will be cached
...@@ -590,14 +575,14 @@ class GraphIndex(object): ...@@ -590,14 +575,14 @@ class GraphIndex(object):
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif type == 'out': elif type == 'out':
row = F.unsqueeze(src, 0) row = F.unsqueeze(src, 0)
col = F.unsqueeze(eid, 0) col = F.unsqueeze(eid, 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((m,), dtype=F.float32, ctx=ctx) dat = F.ones((m,), dtype=F.float32, ctx=ctx)
inc = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
elif type == 'both': elif type == 'both':
# create index # create index
row = F.unsqueeze(F.cat([src, dst], dim=0), 0) row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
...@@ -611,10 +596,11 @@ class GraphIndex(object): ...@@ -611,10 +596,11 @@ class GraphIndex(object):
x[diagonal] = 0 x[diagonal] = 0
y[diagonal] = 0 y[diagonal] = 0
dat = F.cat([x, y], dim=0) dat = F.cat([x, y], dim=0)
inc = F.sparse_matrix(dat, ('coo', idx), (n, m)) inc, shuffle_idx = F.sparse_matrix(dat, ('coo', idx), (n, m))
else: else:
raise DGLError('Invalid incidence matrix type: %s' % str(type)) raise DGLError('Invalid incidence matrix type: %s' % str(type))
return inc shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return inc, shuffle_idx
def to_networkx(self): def to_networkx(self):
"""Convert to networkx graph. """Convert to networkx graph.
......
...@@ -8,7 +8,7 @@ import scipy.sparse as sp ...@@ -8,7 +8,7 @@ import scipy.sparse as sp
from ._ffi.function import _init_api from ._ffi.function import _init_api
from . import backend as F from . import backend as F
from . import utils from . import utils
from .base import ALL, is_all from .base import ALL, is_all, dgl_warning
class ImmutableGraphIndex(object): class ImmutableGraphIndex(object):
"""Graph index object on immutable graphs. """Graph index object on immutable graphs.
...@@ -473,11 +473,16 @@ class ImmutableGraphIndex(object): ...@@ -473,11 +473,16 @@ class ImmutableGraphIndex(object):
------- -------
utils.CtxCachedObject utils.CtxCachedObject
An object that returns tensor given context. An object that returns tensor given context.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
def get_adj(ctx): def get_adj(ctx):
new_mat = self._sparse.adjacency_matrix(transpose) new_mat = self._sparse.adjacency_matrix(transpose)
return F.copy_to(new_mat, ctx) return F.copy_to(new_mat, ctx)
return self._sparse.adjacency_matrix(transpose, ctx) # FIXME(minjie): calculate the shuffle index
dgl_warning('Shuffle index is not correctly computed. SPMV with edge feature might fail!!')
return self._sparse.adjacency_matrix(transpose, ctx), None
def incidence_matrix(self, type, ctx): def incidence_matrix(self, type, ctx):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
...@@ -510,6 +515,9 @@ class ImmutableGraphIndex(object): ...@@ -510,6 +515,9 @@ class ImmutableGraphIndex(object):
------- -------
SparseTensor SparseTensor
The incidence matrix. The incidence matrix.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
raise Exception('immutable graph doesn\'t support incidence_matrix for now.') raise Exception('immutable graph doesn\'t support incidence_matrix for now.')
...@@ -540,9 +548,11 @@ class ImmutableGraphIndex(object): ...@@ -540,9 +548,11 @@ class ImmutableGraphIndex(object):
nx_graph : networkx.DiGraph nx_graph : networkx.DiGraph
The nx graph The nx graph
""" """
assert isinstance(nx_graph, nx.DiGraph), "The input graph has to be a NetworkX DiGraph." if not isinstance(nx_graph, nx.DiGraph):
nx_graph = nx.DiGraph(nx_graph)
# We store edge Ids as an edge attribute. # We store edge Ids as an edge attribute.
out_mat = nx.convert_matrix.to_scipy_sparse_matrix(nx_graph, format='coo') nodelist = list(range(nx_graph.number_of_nodes()))
out_mat = nx.convert_matrix.to_scipy_sparse_matrix(nx_graph, nodelist=nodelist, format='coo')
self._sparse.from_coo_matrix(out_mat) self._sparse.from_coo_matrix(out_mat)
def from_scipy_sparse_matrix(self, adj): def from_scipy_sparse_matrix(self, adj):
......
...@@ -307,7 +307,8 @@ class SPMVWithDataExecutor(Executor): ...@@ -307,7 +307,8 @@ class SPMVWithDataExecutor(Executor):
spA = spA_ctxobj.get(ctx) spA = spA_ctxobj.get(ctx)
spidx = F.sparse_matrix_indices(spA) spidx = F.sparse_matrix_indices(spA)
shape = F.shape(spA) shape = F.shape(spA)
spA = F.sparse_matrix(A_data, spidx, shape) # shuffle index is not used
spA, _ = F.sparse_matrix(A_data, spidx, shape)
if F.ndim(B) == 1: if F.ndim(B) == 1:
# B is a vector, append a (1,) dim at the end # B is a vector, append a (1,) dim at the end
......
...@@ -489,9 +489,9 @@ def _gen_send_reduce( ...@@ -489,9 +489,9 @@ def _gen_send_reduce(
uv_getter : callable uv_getter : callable
A function that returns a pair of var.IDX (u, v) for the triggered edges. A function that returns a pair of var.IDX (u, v) for the triggered edges.
adj_creator : callable adj_creator : callable
A function that returns var.SPMAT that represents the adjmat. A function that returns the adjmat and the shuffle index.
inc_creator : callable inc_creator : callable
A function that returns var.SPMAT that represents the incmat. A function that returns the incmat and the shuffle index.
Returns Returns
------- -------
......
...@@ -80,9 +80,9 @@ def analyze_e2v_spmv(graph, rfunc): ...@@ -80,9 +80,9 @@ def analyze_e2v_spmv(graph, rfunc):
rfunc_left.append(rfn) rfunc_left.append(rfn)
return spmv_rfunc, rfunc_left return spmv_rfunc, rfunc_left
def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out): def gen_v2v_spmv_schedule(adj, spmv_pairs, nf, ef, eid, out):
""" """
adjmat : sparse matrix adj : tuple (sparse matrix, utils.Index)
spmv_pairs : list of pair spmv_pairs : list of pair
nf : var.Var nf : var.Var
input node features input node features
...@@ -93,9 +93,12 @@ def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out): ...@@ -93,9 +93,12 @@ def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out):
out : var.Var out : var.Var
output node features output node features
""" """
adjmat, shuffle_idx = adj
adj_var = var.SPMAT(adjmat) adj_var = var.SPMAT(adjmat)
if shuffle_idx is not None:
new_eid = utils.reorder_index(eid.data, shuffle_idx)
eid = var.IDX(new_eid)
for mfn, rfn in spmv_pairs: for mfn, rfn in spmv_pairs:
#print('v2v mfn=%s rfn=%s' % (mfn.name, rfn.name))
if mfn.use_edge_feature: if mfn.use_edge_feature:
ftedge = ir.READ(ef, eid, var.STR(mfn.edge_field)) ftedge = ir.READ(ef, eid, var.STR(mfn.edge_field))
ftsrc = ir.READ_COL(nf, var.STR(mfn.src_field)) ftsrc = ir.READ_COL(nf, var.STR(mfn.src_field))
...@@ -108,15 +111,15 @@ def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out): ...@@ -108,15 +111,15 @@ def gen_v2v_spmv_schedule(adjmat, spmv_pairs, nf, ef, eid, out):
def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out): def gen_e2v_spmv_schedule(inc, spmv_rfunc, mf, out):
""" """
inc : sparse matrix inc : tuple (sparse matrix, utils.Index)
The incidence matrix
spmv_rfunc : list of builtin reducers spmv_rfunc : list of builtin reducers
mf : var.Var mf : var.Var
Variable for message frame. Variable for message frame.
out : var.Var out : var.Var
Variable for output reduced features. Variable for output reduced features.
""" """
inc_var = var.SPMAT(inc) incmat, _ = inc
inc_var = var.SPMAT(incmat)
for rfn in spmv_rfunc: for rfn in spmv_rfunc:
ftmsg = ir.READ_COL(mf, var.STR(rfn.msg_field)) ftmsg = ir.READ_COL(mf, var.STR(rfn.msg_field))
ftdst = ir.SPMV(inc_var, ftmsg) ftdst = ir.SPMV(inc_var, ftmsg)
...@@ -134,10 +137,14 @@ def build_adj_matrix_graph(graph): ...@@ -134,10 +137,14 @@ def build_adj_matrix_graph(graph):
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx. Get be used to get adjacency matrix on the provided ctx.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
return utils.CtxCachedObject(lambda ctx : graph.adjacency_matrix(ctx=ctx)) adjmat, shuffle_idx = graph._graph.adjacency_matrix(transpose=False, ctx=F.cpu())
return utils.CtxCachedObject(lambda ctx : F.copy_to(adjmat, ctx)), shuffle_idx
def build_adj_matrix_index_uv(graph, edges, reduce_nodes): def _build_adj_matrix_index_uv(graph, edges, reduce_nodes):
"""Build adj matrix index and shape using the given (u, v) edges. """Build adj matrix index and shape using the given (u, v) edges.
The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes The matrix is of shape (len(reduce_nodes), n), where n is the number of nodes
...@@ -198,15 +205,19 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes): ...@@ -198,15 +205,19 @@ def build_adj_matrix_uv(graph, edges, reduce_nodes):
Returns Returns
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get adjacency matrix on the provided ctx. Get be used to get adjacency matrix and on the provided ctx.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
sp_idx, shape = build_adj_matrix_index_uv(graph, edges, reduce_nodes) sp_idx, shape = _build_adj_matrix_index_uv(graph, edges, reduce_nodes)
u, v = edges u, v = edges
nnz = len(u) nnz = len(u)
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu()) dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
mat = F.sparse_matrix(dat, sp_idx, shape) mat, shuffle_idx = F.sparse_matrix(dat, sp_idx, shape)
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)) shuffle_idx = utils.toindex(shuffle_idx) if shuffle_idx is not None else None
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)), shuffle_idx
def build_inc_matrix_graph(graph): def build_inc_matrix_graph(graph):
"""Build incidence matrix. """Build incidence matrix.
...@@ -220,8 +231,13 @@ def build_inc_matrix_graph(graph): ...@@ -220,8 +231,13 @@ def build_inc_matrix_graph(graph):
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx. Get be used to get incidence matrix on the provided ctx.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
return utils.CtxCachedObject(lambda ctx : graph.incidence_matrix(type='in', ctx=ctx)) incmat, _ = graph._graph.incidence_matrix(type='in', ctx=F.cpu())
# inc mat will not use data tensor so conversion index is not needed
return utils.CtxCachedObject(lambda ctx : F.copy_to(incmat, ctx)), None
def build_inc_matrix_eid(m, eid, dst, reduce_nodes): def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
"""Build incidence matrix using edge id and edge dst nodes. """Build incidence matrix using edge id and edge dst nodes.
...@@ -269,6 +285,9 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes): ...@@ -269,6 +285,9 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx. Get be used to get incidence matrix on the provided ctx.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True) new2old, old2new = utils.build_relabel_map(reduce_nodes, sorted=True)
dst = dst.tousertensor() dst = dst.tousertensor()
...@@ -283,8 +302,9 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes): ...@@ -283,8 +302,9 @@ def build_inc_matrix_eid(m, eid, dst, reduce_nodes):
# create dat tensor # create dat tensor
nnz = len(eid) nnz = len(eid)
dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu()) dat = F.ones((nnz,), dtype=F.float32, ctx=F.cpu())
mat = F.sparse_matrix(dat, ('coo', idx), (n, m)) mat, _ = F.sparse_matrix(dat, ('coo', idx), (n, m))
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)) # inc mat will not use data tensor so conversion index is not needed
return utils.CtxCachedObject(lambda ctx : F.copy_to(mat, ctx)), None
def build_inc_matrix_dst(dst, reduce_nodes): def build_inc_matrix_dst(dst, reduce_nodes):
"""Build incidence matrix using only edge destinations. """Build incidence matrix using only edge destinations.
...@@ -318,6 +338,9 @@ def build_inc_matrix_dst(dst, reduce_nodes): ...@@ -318,6 +338,9 @@ def build_inc_matrix_dst(dst, reduce_nodes):
------- -------
utils.CtxCachedObject utils.CtxCachedObject
Get be used to get incidence matrix on the provided ctx. Get be used to get incidence matrix on the provided ctx.
utils.Index
An index for data shuffling due to the sparse format change. Returns None
if shuffle is not required.
""" """
eid = utils.toindex(F.arange(0, len(dst))) eid = utils.toindex(F.arange(0, len(dst)))
return build_inc_matrix_eid(len(eid), eid, dst, reduce_nodes) return build_inc_matrix_eid(len(eid), eid, dst, reduce_nodes)
...@@ -237,8 +237,8 @@ def build_relabel_map(x, sorted=False): ...@@ -237,8 +237,8 @@ def build_relabel_map(x, sorted=False):
unique_x, _ = F.sort_1d(F.unique(x)) unique_x, _ = F.sort_1d(F.unique(x))
else: else:
unique_x = x unique_x = x
map_len = int(F.max(unique_x, dim=0)) + 1 map_len = int(F.asnumpy(F.max(unique_x, dim=0))) + 1
old_to_new = F.zeros(map_len, dtype=F.int64, ctx=F.cpu()) old_to_new = F.zeros((map_len,), dtype=F.int64, ctx=F.cpu())
F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x))) F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x)))
return unique_x, old_to_new return unique_x, old_to_new
...@@ -334,30 +334,20 @@ def reorder(dict_like, index): ...@@ -334,30 +334,20 @@ def reorder(dict_like, index):
new_dict[key] = F.gather_row(val, idx_ctx) new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict return new_dict
def build_coo_sparse_matrix(dat, row, col, dense_shape): def reorder_index(idx, order):
"""Build coo sparse matrix """Reorder the idx according to the given order
Parameters Parameters
---------- ----------
dat: Tensor idx : utils.Index
Data. The index to be reordered.
row: Tensor order : utils.Index
Row index. The order to follow.
col: Tensor
Column index.
dense_shape: list or tuple of two integer
Dense shape of the sparse matrix
Returns
-------
SparseTensor
The sparse matrix.
""" """
nnz = len(row) idx = idx.tousertensor()
row = F.unsqueeze(row, 0) order = order.tousertensor()
col = F.unsqueeze(col, 0) new_idx = F.gather_row(idx, order)
idx = F.cat([row, col], dim=0) return toindex(new_idx)
return F.sparse_matrix(dat, ('coo', idx), dense_shape)
def is_iterable(obj): def is_iterable(obj):
"""Return true if the object is an iterable.""" """Return true if the object is an iterable."""
......
...@@ -14,8 +14,8 @@ def generate_rand_graph(n): ...@@ -14,8 +14,8 @@ def generate_rand_graph(n):
return g, ig return g, ig
def check_graph_equal(g1, g2): def check_graph_equal(g1, g2):
adj1 = g1.adjacency_matrix(transpose=False, ctx=mx.cpu()) != 0 adj1 = g1.adjacency_matrix(transpose=False, ctx=mx.cpu())[0] != 0
adj2 = g2.adjacency_matrix(transpose=False, ctx=mx.cpu()) != 0 adj2 = g2.adjacency_matrix(transpose=False, ctx=mx.cpu())[0] != 0
assert mx.nd.sum(adj1 - adj2).asnumpy() == 0 assert mx.nd.sum(adj1 - adj2).asnumpy() == 0
def test_graph_gen(): def test_graph_gen():
......
...@@ -26,6 +26,7 @@ def generate_graph2(n): ...@@ -26,6 +26,7 @@ def generate_graph2(n):
arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64) arr = (sp.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
g1 = dgl.DGLGraph(arr, readonly=True) g1 = dgl.DGLGraph(arr, readonly=True)
g2 = dgl.DGLGraph(arr, readonly=True) g2 = dgl.DGLGraph(arr, readonly=True)
num_nodes = g1.number_of_nodes() num_nodes = g1.number_of_nodes()
g1.set_n_repr({'f1' : mx.nd.random.normal(shape=(num_nodes,)), g1.set_n_repr({'f1' : mx.nd.random.normal(shape=(num_nodes,)),
'f2' : mx.nd.random.normal(shape=(num_nodes, D))}) 'f2' : mx.nd.random.normal(shape=(num_nodes, D))})
...@@ -308,9 +309,116 @@ def test_send_and_recv_multi_fn(): ...@@ -308,9 +309,116 @@ def test_send_and_recv_multi_fn():
v2 = g.ndata['v2'] v2 = g.ndata['v2']
assert np.allclose(v1.asnumpy(), v2.asnumpy(), rtol=1e-05, atol=1e-05) assert np.allclose(v1.asnumpy(), v2.asnumpy(), rtol=1e-05, atol=1e-05)
############################ Copy from torch
D = 5
def simple_graph():
    """Build a 10-node fan-in/fan-out test graph with random features.

    Node 0 fans out to nodes 1..8, each of which feeds into node 9; a
    final back edge 9 -> 0 closes the flow (17 edges total).  Node
    features 'f1' (1-D) and 'f2' (2-D) and edge features 'e1' (1-D) and
    'e2' (the same weights with an extra trailing dim) are sampled from
    a normal distribution.
    """
    graph = dgl.DGLGraph()
    graph.add_nodes(10)
    # 0 is the source and 9 the sink; every middle node lies on a 0->mid->9 path
    for mid in range(1, 9):
        graph.add_edge(0, mid)
        graph.add_edge(mid, 9)
    # back edge from the sink to the source
    graph.add_edge(9, 0)
    graph.set_n_repr({'f1' : mx.nd.random.normal(shape=(10,)),
                      'f2' : mx.nd.random.normal(shape=(10, D))})
    w = mx.nd.random.normal(shape=(17,))
    graph.set_e_repr({'e1': w, 'e2': mx.nd.expand_dims(w, 1)})
    return graph
def test_v2v_update_all_sum():
    """Builtin copy_src/src_mul_edge + sum must match the equivalent UDFs."""
    def _run(fld):
        # UDF counterparts of the builtin message functions
        def udf_copy_src(edges):
            return {'m' : edges.src[fld]}
        def udf_src_mul_edge(edges):
            feat = edges.src[fld]
            if len(feat.shape) == 1:
                return {'m' : feat * edges.data['e1']}
            return {'m' : feat * edges.data['e2']}
        # UDF counterparts of the builtin reducer / apply step
        def udf_sum(nodes):
            return {fld : mx.nd.sum(nodes.mailbox['m'], axis=1)}
        def udf_apply(nodes):
            return {fld : 2 * nodes.data[fld]}
        g = simple_graph()
        # --- builtin vs UDF, no edge weights ---
        saved = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), udf_apply)
        builtin_out = g.ndata[fld]
        g.set_n_repr({fld : saved})
        g.update_all(udf_copy_src, udf_sum, udf_apply)
        udf_out = g.ndata[fld]
        assert np.allclose(builtin_out.asnumpy(), udf_out.asnumpy(), rtol=1e-05, atol=1e-05)
        # --- builtin vs UDF, with edge weights ('e1' flat vs 'e2' unsqueezed) ---
        saved = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.sum(msg='m', out=fld), udf_apply)
        out_e1 = g.ndata[fld]
        g.set_n_repr({fld : saved})
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                     fn.sum(msg='m', out=fld), udf_apply)
        out_e2 = g.ndata[fld].squeeze()
        g.set_n_repr({fld : saved})
        g.update_all(udf_src_mul_edge, udf_sum, udf_apply)
        out_udf = g.ndata[fld]
        assert np.allclose(out_e1.asnumpy(), out_e2.asnumpy(), rtol=1e-05, atol=1e-05)
        assert np.allclose(out_e2.asnumpy(), out_udf.asnumpy(), rtol=1e-05, atol=1e-05)
    _run('f1')  # 1-D node features
    _run('f2')  # 2-D node features
def test_v2v_update_all_max():
    """Builtin copy_src/src_mul_edge + max must match the equivalent UDFs."""
    def _run(fld):
        # UDF counterparts of the builtin message functions
        def udf_copy_src(edges):
            return {'m' : edges.src[fld]}
        def udf_src_mul_edge(edges):
            feat = edges.src[fld]
            if len(feat.shape) == 1:
                return {'m' : feat * edges.data['e1']}
            return {'m' : feat * edges.data['e2']}
        # UDF counterparts of the builtin reducer / apply step
        def udf_max(nodes):
            return {fld : mx.nd.max(nodes.mailbox['m'], axis=1)}
        def udf_apply(nodes):
            return {fld : 2 * nodes.data[fld]}
        g = simple_graph()
        # --- builtin vs UDF, no edge weights ---
        saved = g.ndata[fld]
        g.update_all(fn.copy_src(src=fld, out='m'), fn.max(msg='m', out=fld), udf_apply)
        builtin_out = g.ndata[fld]
        g.set_n_repr({fld : saved})
        g.update_all(udf_copy_src, udf_max, udf_apply)
        udf_out = g.ndata[fld]
        assert np.allclose(builtin_out.asnumpy(), udf_out.asnumpy(), rtol=1e-05, atol=1e-05)
        # --- builtin vs UDF, with edge weights ('e1' flat vs 'e2' unsqueezed) ---
        saved = g.ndata[fld]
        g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
                     fn.max(msg='m', out=fld), udf_apply)
        out_e1 = g.ndata[fld]
        g.set_n_repr({fld : saved})
        g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
                     fn.max(msg='m', out=fld), udf_apply)
        out_e2 = g.ndata[fld].squeeze()
        g.set_n_repr({fld : saved})
        g.update_all(udf_src_mul_edge, udf_max, udf_apply)
        out_udf = g.ndata[fld]
        assert np.allclose(out_e1.asnumpy(), out_e2.asnumpy(), rtol=1e-05, atol=1e-05)
        assert np.allclose(out_e2.asnumpy(), out_udf.asnumpy(), rtol=1e-05, atol=1e-05)
    _run('f1')  # 1-D node features
    _run('f2')  # 2-D node features
############################ Copy from torch
if __name__ == '__main__': if __name__ == '__main__':
test_update_all() test_update_all()
test_pull() test_pull()
test_send_and_recv() test_send_and_recv()
test_update_all_multi_fn() test_update_all_multi_fn()
test_send_and_recv_multi_fn() test_send_and_recv_multi_fn()
test_v2v_update_all_sum()
test_v2v_update_all_max()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment