Unverified Commit 048f6d7a authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[GraphIndex] refactor graph caching (#150)

* refactor graph caching

* fix mx test

* fix typo
parent a9ffb59e
...@@ -77,6 +77,9 @@ def tensor(data, dtype=None): ...@@ -77,6 +77,9 @@ def tensor(data, dtype=None):
def sparse_matrix(data, index, shape, force_format=False): def sparse_matrix(data, index, shape, force_format=False):
"""Create a sparse matrix. """Create a sparse matrix.
NOTE: Please make sure that the data and index tensors are not
copied. This is critical to the performance.
Parameters Parameters
---------- ----------
data : Tensor data : Tensor
...@@ -482,7 +485,7 @@ def reshape(input, shape): ...@@ -482,7 +485,7 @@ def reshape(input, shape):
""" """
pass pass
def zeros(shape, dtype): def zeros(shape, dtype, ctx):
"""Create a zero tensor. """Create a zero tensor.
Parameters Parameters
...@@ -491,6 +494,8 @@ def zeros(shape, dtype): ...@@ -491,6 +494,8 @@ def zeros(shape, dtype):
The tensor shape. The tensor shape.
dtype : data type dtype : data type
It should be one of the values in the data type dict. It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns Returns
------- -------
...@@ -499,7 +504,7 @@ def zeros(shape, dtype): ...@@ -499,7 +504,7 @@ def zeros(shape, dtype):
""" """
pass pass
def ones(shape, dtype): def ones(shape, dtype, ctx):
"""Create a one tensor. """Create a one tensor.
Parameters Parameters
...@@ -508,6 +513,8 @@ def ones(shape, dtype): ...@@ -508,6 +513,8 @@ def ones(shape, dtype):
The tensor shape. The tensor shape.
dtype : data type dtype : data type
It should be one of the values in the data type dict. It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns Returns
------- -------
......
...@@ -268,7 +268,7 @@ class ImmutableGraphIndex(object): ...@@ -268,7 +268,7 @@ class ImmutableGraphIndex(object):
induced_es.append(induced_e) induced_es.append(induced_e)
return gis, induced_ns, induced_es return gis, induced_ns, induced_es
def adjacency_matrix(self, transpose=False): def adjacency_matrix(self, transpose, ctx):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination By default, a row of returned adjacency matrix represents the destination
...@@ -281,6 +281,8 @@ class ImmutableGraphIndex(object): ...@@ -281,6 +281,8 @@ class ImmutableGraphIndex(object):
---------- ----------
transpose : bool transpose : bool
A flag to tranpose the returned adjacency matrix. A flag to tranpose the returned adjacency matrix.
ctx : context
The device context of the returned matrix.
Returns Returns
------- -------
...@@ -294,7 +296,7 @@ class ImmutableGraphIndex(object): ...@@ -294,7 +296,7 @@ class ImmutableGraphIndex(object):
indices = mat.indices indices = mat.indices
indptr = mat.indptr indptr = mat.indptr
data = mx.nd.ones(indices.shape, dtype=np.float32) data = mx.nd.ones(indices.shape, dtype=np.float32, ctx=ctx)
return mx.nd.sparse.csr_matrix((data, indices, indptr), shape=mat.shape) return mx.nd.sparse.csr_matrix((data, indices, indptr), shape=mat.shape)
def from_coo_matrix(self, out_coo): def from_coo_matrix(self, out_coo):
......
...@@ -114,11 +114,11 @@ def reshape(input, shape): ...@@ -114,11 +114,11 @@ def reshape(input, shape):
# NOTE: the input cannot be a symbol # NOTE: the input cannot be a symbol
return nd.reshape(input ,shape) return nd.reshape(input ,shape)
def zeros(shape, dtype): def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype) return nd.zeros(shape, dtype=dtype, ctx=ctx)
def ones(shape, dtype): def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype) return nd.ones(shape, dtype=dtype, ctx=ctx)
def spmm(x, y): def spmm(x, y):
return nd.dot(x, y) return nd.dot(x, y)
......
...@@ -23,7 +23,8 @@ def sparse_matrix(data, index, shape, force_format=False): ...@@ -23,7 +23,8 @@ def sparse_matrix(data, index, shape, force_format=False):
fmt = index[0] fmt = index[0]
if fmt != 'coo': if fmt != 'coo':
raise TypeError('Pytorch backend only supports COO format. But got %s.' % fmt) raise TypeError('Pytorch backend only supports COO format. But got %s.' % fmt)
return th.sparse.FloatTensor(index[1], data, shape) # NOTE: use _sparse_coo_tensor_unsafe to avoid unnecessary boundary check
return th._sparse_coo_tensor_unsafe(index[1], data, shape)
def sparse_matrix_indices(spmat): def sparse_matrix_indices(spmat):
return ('coo', spmat._indices()) return ('coo', spmat._indices())
...@@ -98,11 +99,11 @@ def unsqueeze(input, dim): ...@@ -98,11 +99,11 @@ def unsqueeze(input, dim):
def reshape(input, shape): def reshape(input, shape):
return th.reshape(input ,shape) return th.reshape(input ,shape)
def zeros(shape, dtype): def zeros(shape, dtype, ctx):
return th.zeros(shape, dtype=dtype) return th.zeros(shape, dtype=dtype, device=ctx)
def ones(shape, dtype): def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype) return th.ones(shape, dtype=dtype, device=ctx)
def spmm(x, y): def spmm(x, y):
return th.spmm(x, y) return th.spmm(x, y)
......
...@@ -183,7 +183,7 @@ class Frame(MutableMapping): ...@@ -183,7 +183,7 @@ class Frame(MutableMapping):
dgl_warning('Initializer is not set. Use zero initializer instead.' dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to' ' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.') ' explicitly specify which initializer to use.')
self._initializer = lambda shape, dtype: F.zeros(shape, dtype) self._initializer = lambda shape, dtype, ctx: F.zeros(shape, dtype, ctx)
def set_initializer(self, initializer): def set_initializer(self, initializer):
"""Set the initializer for empty values. """Set the initializer for empty values.
...@@ -283,9 +283,7 @@ class Frame(MutableMapping): ...@@ -283,9 +283,7 @@ class Frame(MutableMapping):
' one column in the frame so number of rows can be inferred.' % name) ' one column in the frame so number of rows can be inferred.' % name)
if self.initializer is None: if self.initializer is None:
self._warn_and_set_initializer() self._warn_and_set_initializer()
# TODO(minjie): directly init data on the targer device. init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
init_data = F.copy_to(init_data, ctx)
self._columns[name] = Column(init_data, scheme) self._columns[name] = Column(init_data, scheme)
def update_column(self, name, data): def update_column(self, name, data):
...@@ -601,10 +599,10 @@ class FrameRef(MutableMapping): ...@@ -601,10 +599,10 @@ class FrameRef(MutableMapping):
for key in self._frame: for key in self._frame:
scheme = self._frame[key].scheme scheme = self._frame[key].scheme
ctx = F.context(self._frame[key].data)
if self._frame.initializer is None: if self._frame.initializer is None:
self._frame._warn_and_set_initializer() self._frame._warn_and_set_initializer()
new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype) new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype, ctx)
feat_placeholders[key] = new_data feat_placeholders[key] = new_data
self.append(feat_placeholders) self.append(feat_placeholders)
......
...@@ -733,7 +733,8 @@ class DGLGraph(object): ...@@ -733,7 +733,8 @@ class DGLGraph(object):
def set_n_initializer(self, initializer): def set_n_initializer(self, initializer):
"""Set the initializer for empty node features. """Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type. Initializer is a callable that returns a tensor given the shape, data type
and device context.
Parameters Parameters
---------- ----------
...@@ -745,7 +746,8 @@ class DGLGraph(object): ...@@ -745,7 +746,8 @@ class DGLGraph(object):
def set_e_initializer(self, initializer): def set_e_initializer(self, initializer):
"""Set the initializer for empty edge features. """Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type. Initializer is a callable that returns a tensor given the shape, data type
and device context.
Parameters Parameters
---------- ----------
...@@ -1509,12 +1511,20 @@ class DGLGraph(object): ...@@ -1509,12 +1511,20 @@ class DGLGraph(object):
self._edge_frame.num_rows, self._edge_frame.num_rows,
reduce_func) reduce_func)
def adjacency_matrix(self, ctx=F.cpu()): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
of an edge and the column represents the source.
When transpose is True, a row represents the source and a column represents
a destination.
Parameters Parameters
---------- ----------
ctx : optional transpose : bool, optional (default=False)
A flag to tranpose the returned adjacency matrix.
ctx : context, optional (default=cpu)
The context of returned adjacency matrix. The context of returned adjacency matrix.
Returns Returns
...@@ -1522,7 +1532,10 @@ class DGLGraph(object): ...@@ -1522,7 +1532,10 @@ class DGLGraph(object):
sparse_tensor sparse_tensor
The adjacency matrix. The adjacency matrix.
""" """
return self._graph.adjacency_matrix().get(ctx) if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
return self._graph.adjacency_matrix(transpose, ctx)
def incidence_matrix(self, oriented=False, ctx=F.cpu()): def incidence_matrix(self, oriented=False, ctx=F.cpu()):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
...@@ -1540,7 +1553,10 @@ class DGLGraph(object): ...@@ -1540,7 +1553,10 @@ class DGLGraph(object):
sparse_tensor sparse_tensor
The incidence matrix. The incidence matrix.
""" """
return self._graph.incidence_matrix(oriented).get(ctx) if not isinstance(oriented, bool):
raise DGLError('Expect bool value for "oriented" arg,'
' but got %s.' % (type(oriented)))
return self._graph.incidence_matrix(oriented, ctx)
def line_graph(self, backtracking=True, shared=False): def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
......
...@@ -7,6 +7,7 @@ import scipy ...@@ -7,6 +7,7 @@ import scipy
from ._ffi.base import c_array from ._ffi.base import c_array
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F from . import backend as F
from . import utils from . import utils
from .immutable_graph_index import create_immutable_graph_index from .immutable_graph_index import create_immutable_graph_index
...@@ -347,11 +348,14 @@ class GraphIndex(object): ...@@ -347,11 +348,14 @@ class GraphIndex(object):
utils.Index utils.Index
The edge ids. The edge ids.
""" """
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted) key = 'edges_s%d' % sorted
src = utils.toindex(edge_array(0)) if key not in self._cache:
dst = utils.toindex(edge_array(1)) edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
eid = utils.toindex(edge_array(2)) src = utils.toindex(edge_array(0))
return src, dst, eid dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
self._cache[key] = (src, dst, eid)
return self._cache[key]
def in_degree(self, v): def in_degree(self, v):
"""Return the in degree of the node. """Return the in degree of the node.
...@@ -470,7 +474,7 @@ class GraphIndex(object): ...@@ -470,7 +474,7 @@ class GraphIndex(object):
induced_nodes = utils.toindex(rst(1)) induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e) return SubgraphIndex(rst(0), self, induced_nodes, e)
def adjacency_matrix(self, transpose=False): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination By default, a row of returned adjacency matrix represents the destination
...@@ -481,31 +485,30 @@ class GraphIndex(object): ...@@ -481,31 +485,30 @@ class GraphIndex(object):
Parameters Parameters
---------- ----------
transpose : bool transpose : bool, optional (default=False)
A flag to tranpose the returned adjacency matrix. A flag to tranpose the returned adjacency matrix.
Returns Returns
------- -------
utils.CtxCachedObject SparseTensor
An object that returns tensor given context. The adjacency matrix.
""" """
key = 'transposed adj' if transpose else 'adj' src, dst, _ = self.edges(sorted=False)
if not key in self._cache: src = src.tousertensor(ctx) # the index of the ctx will be cached
src, dst, _ = self.edges(sorted=False) dst = dst.tousertensor(ctx) # the index of the ctx will be cached
src = F.unsqueeze(src.tousertensor(), 0) src = F.unsqueeze(src, dim=0)
dst = F.unsqueeze(dst.tousertensor(), 0) dst = F.unsqueeze(dst, dim=0)
if transpose: if transpose:
idx = F.cat([src, dst], dim=0) idx = F.cat([src, dst], dim=0)
else: else:
idx = F.cat([dst, src], dim=0) idx = F.cat([dst, src], dim=0)
n = self.number_of_nodes() n = self.number_of_nodes()
# FIXME(minjie): data type # FIXME(minjie): data type
dat = F.ones((self.number_of_edges(),), dtype=F.float32) dat = F.ones((self.number_of_edges(),), dtype=F.float32, ctx=ctx)
mat = F.sparse_matrix(dat, ('coo', idx), (n, n)) adj = F.sparse_matrix(dat, ('coo', idx), (n, n))
self._cache[key] = utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx)) return adj
return self._cache[key]
def incidence_matrix(self, oriented=False, ctx=F.cpu()):
def incidence_matrix(self, oriented=False):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
Parameters Parameters
...@@ -515,38 +518,35 @@ class GraphIndex(object): ...@@ -515,38 +518,35 @@ class GraphIndex(object):
Returns Returns
------- -------
utils.CtxCachedObject SparseTensor
An object that returns tensor given context. The incidence matrix.
""" """
key = ('oriented ' if oriented else '') + 'incidence matrix' src, dst, eid = self.edges(sorted=False)
if not key in self._cache: src = src.tousertensor(ctx) # the index of the ctx will be cached
src, dst, _ = self.edges(sorted=False) dst = dst.tousertensor(ctx) # the index of the ctx will be cached
src = src.tousertensor() eid = eid.tousertensor(ctx) # the index of the ctx will be cached
dst = dst.tousertensor() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
eid = F.arange(0, m) # create index
row = F.unsqueeze(F.cat([src, dst], dim=0), 0) row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
col = F.unsqueeze(F.cat([eid, eid], dim=0), 0) col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
idx = F.cat([row, col], dim=0) idx = F.cat([row, col], dim=0)
# create data
diagonal = (src == dst) diagonal = (src == dst)
if oriented: if oriented:
# FIXME(minjie): data type # FIXME(minjie): data type
x = -F.ones((m,), dtype=F.float32) x = -F.ones((m,), dtype=F.float32, ctx=ctx)
y = F.ones((m,), dtype=F.float32) y = F.ones((m,), dtype=F.float32, ctx=ctx)
x[diagonal] = 0 x[diagonal] = 0
y[diagonal] = 0 y[diagonal] = 0
dat = F.cat([x, y], dim=0) dat = F.cat([x, y], dim=0)
else: else:
# FIXME(minjie): data type # FIXME(minjie): data type
x = F.ones((m,), dtype=F.float32) x = F.ones((m,), dtype=F.float32, ctx=ctx)
x[diagonal] = 0 x[diagonal] = 0
dat = F.cat([x, x], dim=0) dat = F.cat([x, x], dim=0)
n = self.number_of_nodes() inc = F.sparse_matrix(dat, ('coo', idx), (n, m))
mat = F.sparse_matrix(dat, ('coo', idx), (n, m)) return inc
self._cache[key] = utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx))
return self._cache[key]
def to_networkx(self): def to_networkx(self):
"""Convert to networkx graph. """Convert to networkx graph.
......
...@@ -429,7 +429,7 @@ class ImmutableGraphIndex(object): ...@@ -429,7 +429,7 @@ class ImmutableGraphIndex(object):
return [ImmutableSubgraphIndex(gi, self, induced_n, return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)] induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]
def adjacency_matrix(self, transpose=False): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination By default, a row of returned adjacency matrix represents the destination
...@@ -451,13 +451,7 @@ class ImmutableGraphIndex(object): ...@@ -451,13 +451,7 @@ class ImmutableGraphIndex(object):
def get_adj(ctx): def get_adj(ctx):
new_mat = self._sparse.adjacency_matrix(transpose) new_mat = self._sparse.adjacency_matrix(transpose)
return F.copy_to(new_mat, ctx) return F.copy_to(new_mat, ctx)
return self._sparse.adjacency_matrix(transpose, ctx)
if not transpose and 'in_adj' in self._cache:
return self._cache['in_adj']
elif transpose and 'out_adj' in self._cache:
return self._cache['out_adj']
else:
return utils.CtxCachedObject(lambda ctx: get_adj(ctx))
def incidence_matrix(self, oriented=False): def incidence_matrix(self, oriented=False):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
......
...@@ -276,10 +276,10 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -276,10 +276,10 @@ class UpdateAllExecutor(BasicExecutor):
if len(F.shape(dat)) > 1: if len(F.shape(dat)) > 1:
# The edge feature is of shape (N, 1) # The edge feature is of shape (N, 1)
dat = F.squeeze(dat, 1) dat = F.squeeze(dat, 1)
idx = F.sparse_matrix_indices(self.g.adjacency_matrix(ctx)) idx = F.sparse_matrix_indices(self.g.adjacency_matrix(ctx=ctx))
adjmat = F.sparse_matrix(dat, idx, self.graph_shape) adjmat = F.sparse_matrix(dat, idx, self.graph_shape)
else: else:
adjmat = self.g.adjacency_matrix(ctx) adjmat = self.g.adjacency_matrix(ctx=ctx)
return adjmat return adjmat
...@@ -347,8 +347,7 @@ class SendRecvExecutor(BasicExecutor): ...@@ -347,8 +347,7 @@ class SendRecvExecutor(BasicExecutor):
# edge feature is of shape (N, 1) # edge feature is of shape (N, 1)
dat = F.squeeze(dat, dim=1) dat = F.squeeze(dat, dim=1)
else: else:
# TODO(minjie): data type should be adjusted according t othe usage. dat = F.ones((len(self.u), ), dtype=F.float32, ctx=ctx)
dat = F.ones((len(self.u), ), dtype=F.float32)
adjmat = F.sparse_matrix(dat, ('coo', self.graph_idx), self.graph_shape) adjmat = F.sparse_matrix(dat, ('coo', self.graph_idx), self.graph_shape)
return F.copy_to(adjmat, ctx) return F.copy_to(adjmat, ctx)
......
...@@ -244,7 +244,7 @@ def build_relabel_map(x): ...@@ -244,7 +244,7 @@ def build_relabel_map(x):
x = x.tousertensor() x = x.tousertensor()
unique_x, _ = F.sort_1d(F.unique(x)) unique_x, _ = F.sort_1d(F.unique(x))
map_len = int(F.max(unique_x, dim=0)) + 1 map_len = int(F.max(unique_x, dim=0)) + 1
old_to_new = F.zeros(map_len, dtype=F.int64) old_to_new = F.zeros(map_len, dtype=F.int64, ctx=F.cpu())
F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x))) F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x)))
return unique_x, old_to_new return unique_x, old_to_new
......
...@@ -14,8 +14,8 @@ def generate_rand_graph(n): ...@@ -14,8 +14,8 @@ def generate_rand_graph(n):
return g, ig return g, ig
def check_graph_equal(g1, g2): def check_graph_equal(g1, g2):
adj1 = g1.adjacency_matrix().get(mx.cpu()) != 0 adj1 = g1.adjacency_matrix(ctx=mx.cpu()) != 0
adj2 = g2.adjacency_matrix().get(mx.cpu()) != 0 adj2 = g2.adjacency_matrix(ctx=mx.cpu()) != 0
assert mx.nd.sum(adj1 - adj2).asnumpy() == 0 assert mx.nd.sum(adj1 - adj2).asnumpy() == 0
def test_graph_gen(): def test_graph_gen():
......
...@@ -40,8 +40,8 @@ def generate_graph(grad=False): ...@@ -40,8 +40,8 @@ def generate_graph(grad=False):
ecol = Variable(th.randn(17, D), requires_grad=grad) ecol = Variable(th.randn(17, D), requires_grad=grad)
g.ndata['h'] = ncol g.ndata['h'] = ncol
g.edata['w'] = ecol g.edata['w'] = ecol
g.set_n_initializer(lambda shape, dtype : th.zeros(shape)) g.set_n_initializer(lambda shape, dtype, ctx : th.zeros(shape, dtype=dtype, device=ctx))
g.set_e_initializer(lambda shape, dtype : th.zeros(shape)) g.set_e_initializer(lambda shape, dtype, ctx : th.zeros(shape, dtype=dtype, device=ctx))
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
......
import time
import math
import numpy as np
import scipy.sparse as sp
import torch as th
import dgl
def test_adjmat_speed():
    """Check that the adjacency matrix is cached after the first call.

    The first ``adjacency_matrix()`` call constructs the sparse matrix;
    the second should hit the graph-index cache and be much faster.
    """
    n = 1000
    # expected edge density ~ 10*log(n)/n keeps the random graph sparse
    p = 10 * math.log(n) / n
    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
    g = dgl.DGLGraph(a)
    # the first call should construct the adj
    # NOTE: use perf_counter (monotonic, high resolution) rather than
    # time.time(), whose resolution can be ~16ms on some platforms and
    # which can jump backwards on clock adjustment.
    t0 = time.perf_counter()
    g.adjacency_matrix()
    dur1 = time.perf_counter() - t0
    # the second call should be cached and should be very fast
    t0 = time.perf_counter()
    g.adjacency_matrix()
    dur2 = time.perf_counter() - t0
    assert dur2 < dur1 / 5
def test_incmat_speed():
    """Check that the incidence matrix is cached after the first call.

    Mirrors ``test_adjmat_speed``: the second ``incidence_matrix()`` call
    should reuse the cached result and take no longer than the first.
    """
    n = 1000
    # expected edge density ~ 10*log(n)/n keeps the random graph sparse
    p = 10 * math.log(n) / n
    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
    g = dgl.DGLGraph(a)
    # the first call should construct the incidence matrix
    # NOTE: perf_counter is the correct clock for short-interval timing;
    # time.time() has coarse resolution and is not monotonic.
    t0 = time.perf_counter()
    g.incidence_matrix()
    dur1 = time.perf_counter() - t0
    # the second call should be cached and should be very fast
    t0 = time.perf_counter()
    g.incidence_matrix()
    dur2 = time.perf_counter() - t0
    assert dur2 < dur1
if __name__ == '__main__':
    # Run both caching speed checks when invoked as a script.
    for check in (test_adjmat_speed, test_incmat_speed):
        check()
import torch as th import torch as th
import dgl import dgl
import utils as U
def test_simple_readout(): def test_simple_readout():
g1 = dgl.DGLGraph() g1 = dgl.DGLGraph()
...@@ -29,26 +30,26 @@ def test_simple_readout(): ...@@ -29,26 +30,26 @@ def test_simple_readout():
g2.ndata['w'] = w2 g2.ndata['w'] = w2
g1.edata['x'] = e1 g1.edata['x'] = e1
assert th.allclose(dgl.sum_nodes(g1, 'x'), s1) assert U.allclose(dgl.sum_nodes(g1, 'x'), s1)
assert th.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1) assert U.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
assert th.allclose(dgl.sum_edges(g1, 'x'), se1) assert U.allclose(dgl.sum_edges(g1, 'x'), se1)
assert th.allclose(dgl.mean_nodes(g1, 'x'), m1) assert U.allclose(dgl.mean_nodes(g1, 'x'), m1)
assert th.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1) assert U.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
assert th.allclose(dgl.mean_edges(g1, 'x'), me1) assert U.allclose(dgl.mean_edges(g1, 'x'), me1)
g = dgl.batch([g1, g2]) g = dgl.batch([g1, g2])
s = dgl.sum_nodes(g, 'x') s = dgl.sum_nodes(g, 'x')
m = dgl.mean_nodes(g, 'x') m = dgl.mean_nodes(g, 'x')
assert th.allclose(s, th.stack([s1, s2], 0)) assert U.allclose(s, th.stack([s1, s2], 0))
assert th.allclose(m, th.stack([m1, m2], 0)) assert U.allclose(m, th.stack([m1, m2], 0))
ws = dgl.sum_nodes(g, 'x', 'w') ws = dgl.sum_nodes(g, 'x', 'w')
wm = dgl.mean_nodes(g, 'x', 'w') wm = dgl.mean_nodes(g, 'x', 'w')
assert th.allclose(ws, th.stack([ws1, ws2], 0)) assert U.allclose(ws, th.stack([ws1, ws2], 0))
assert th.allclose(wm, th.stack([wm1, wm2], 0)) assert U.allclose(wm, th.stack([wm1, wm2], 0))
s = dgl.sum_edges(g, 'x') s = dgl.sum_edges(g, 'x')
m = dgl.mean_edges(g, 'x') m = dgl.mean_edges(g, 'x')
assert th.allclose(s, th.stack([se1, th.zeros(5)], 0)) assert U.allclose(s, th.stack([se1, th.zeros(5)], 0))
assert th.allclose(m, th.stack([me1, th.zeros(5)], 0)) assert U.allclose(m, th.stack([me1, th.zeros(5)], 0))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment