Unverified Commit 68ec6247 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[API][Doc] API change & basic tutorials (#113)

* Add SH tutorials

* setup sphinx-gallery; work on graph tutorial

* draft dglgraph tutorial

* update readme to include document url

* rm obsolete file

* Draft the message passing tutorial

* Capsule code (#102)

* add capsule example

* clean code

* better naming

* better naming

* [GCN]tutorial scaffold

* fix capsule example code

* remove previous capsule example code

* graph struc edit

* modified:   2_graph.py

* update doc of capsule

* update capsule docs

* update capsule docs

* add msg passing prime

* GCN-GAT tutorial Section 1 and 2

* comment for API improvement

* section 3

* Tutorial API change (#115)

* change the API as discusses; toy example

* enable the new set/get syntax

* fixed pytorch utest

* fixed gcn example

* fixed gat example

* fixed mx utests

* fix mx utest

* delete apply edges; add utest for update_edges

* small change on toy example

* fix utest

* fix out in degrees bug

* update pagerank example and add it to CI

* add delitem for dataview

* make edges() return form that is compatible with send/update_edges etc

* fix index bug when the given data is one-int-tensor

* fix doc
parent 2ecd2b23
......@@ -17,7 +17,7 @@ def build_dgl() {
}
dir ('build') {
sh 'cmake ..'
sh 'make -j$(nproc)'
sh 'make -j4'
}
}
......
......@@ -46,6 +46,7 @@ extensions = [
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.graphviz',
'sphinx_gallery.gen_gallery',
]
......
from __future__ import division
import networkx as nx
from dgl.graph import DGLGraph
DAMP = 0.85
N = 100
K = 10
def message_func(src, dst, edge):
return src['pv'] / src['deg']
def update_func(node, accum):
pv = (1 - DAMP) / N + DAMP * accum
return {'pv' : pv}
def compute_pagerank(g):
g = DGLGraph(g)
print(g.number_of_edges(), g.number_of_nodes())
g.register_message_func(message_func)
g.register_update_func(update_func)
g.register_reduce_func('sum')
# init pv value
for n in g.nodes():
g.node[n]['pv'] = 1 / N
g.node[n]['deg'] = g.out_degree(n)
# pagerank
for k in range(K):
g.update_all()
return [g.node[n]['pv'] for n in g.nodes()]
if __name__ == '__main__':
g = nx.erdos_renyi_graph(N, 0.05)
pv = compute_pagerank(g)
print(pv)
......@@ -16,18 +16,18 @@ import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gat_message(src, edge):
return {'ft' : src['ft'], 'a2' : src['a2']}
def gat_message(edges):
return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
class GATReduce(nn.Module):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop
def forward(self, node, msgs):
a1 = torch.unsqueeze(node['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D)
def forward(self, nodes):
a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = F.softmax(F.leaky_relu(a), dim=1)
......@@ -46,13 +46,13 @@ class GATFinalize(nn.Module):
if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node):
ret = node['accum']
def forward(self, nodes):
ret = nodes.data['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
ret = self.residual_fc(nodes.data['h']) + ret
else:
ret = node['h'] + ret
ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(nn.Module):
......@@ -120,7 +120,7 @@ class GAT(nn.Module):
for hid in range(self.num_heads):
i = l * self.num_heads + hid
# prepare
self.g.set_n_repr(self.prp[i](last))
self.g.ndata.update(self.prp[i](last))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads
......@@ -128,7 +128,7 @@ class GAT(nn.Module):
[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1)
# output projection
self.g.set_n_repr(self.prp[-1](last))
self.g.ndata.update(self.prp[-1](last))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0')
......
......@@ -14,24 +14,22 @@ import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
return {'m' : src['h']}
def gcn_msg(edges):
return {'m' : edges.src['h']}
def gcn_reduce(node, msgs):
return {'h' : torch.sum(msgs['m'], 1)}
def gcn_reduce(nodes):
return {'h' : torch.sum(nodes.mailbox['m'], 1)}
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node['h'])
def forward(self, nodes):
h = self.linear(nodes.data['h'])
if self.activation:
h = self.activation(h)
return {'h' : h}
class GCN(nn.Module):
......@@ -62,13 +60,13 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr({'h' : features})
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
if self.dropout:
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h')
......
......@@ -22,8 +22,8 @@ class NodeApplyModule(nn.Module):
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node['h'])
def forward(self, nodes):
h = self.linear(nodes.data['h'])
if self.activation:
h = self.activation(h)
......@@ -57,13 +57,13 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr({'h' : features})
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
if self.dropout:
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'),
layer)
......
import networkx as nx
import torch
import dgl
import dgl.function as fn
N = 100
g = nx.nx.erdos_renyi_graph(N, 0.05)
g = dgl.DGLGraph(g)
DAMP = 0.85
K = 10
def compute_pagerank(g):
g.ndata['pv'] = torch.ones(N) / N
degrees = g.out_degrees(g.nodes()).type(torch.float32)
for k in range(K):
g.ndata['pv'] = g.ndata['pv'] / degrees
g.update_all(message_func=fn.copy_src(src='pv', out='m'),
reduce_func=fn.sum(msg='m', out='pv'))
g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
return g.ndata['pv']
pv = compute_pagerank(g)
print(pv)
......@@ -11,3 +11,4 @@ from .base import ALL
from .batched_graph import *
from .graph import DGLGraph
from .subgraph import DGLSubGraph
from .udf import NodeBatch, EdgeBatch
......@@ -10,7 +10,7 @@ __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object):
"""Base builtin message function class."""
def __call__(self, src, edge):
def __call__(self, edges):
"""Regular computation of this builtin.
This will be used when optimization is not available.
......@@ -38,14 +38,10 @@ class BundledMessageFunction(MessageFunction):
return False
return True
def __call__(self, src, edge):
ret = None
def __call__(self, edges):
ret = dict()
for fn in self.fn_list:
msg = fn(src, edge)
if ret is None:
ret = msg
else:
# ret and msg must be dict
msg = fn(edges)
ret.update(msg)
return ret
......@@ -83,8 +79,9 @@ class SrcMulEdgeMessageFunction(MessageFunction):
return _is_spmv_supported_node_feat(g, self.src_field) \
and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge):
ret = self.mul_op(src[self.src_field], edge[self.edge_field])
def __call__(self, edges):
ret = self.mul_op(edges.src[self.src_field],
edges.data[self.edge_field])
return {self.out_field : ret}
def name(self):
......@@ -98,8 +95,8 @@ class CopySrcMessageFunction(MessageFunction):
def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge):
return {self.out_field : src[self.src_field]}
def __call__(self, edges):
return {self.out_field : edges.src[self.src_field]}
def name(self):
return "copy_src"
......@@ -114,15 +111,8 @@ class CopyEdgeMessageFunction(MessageFunction):
return False
# return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge):
if self.edge_field is not None:
ret = edge[self.edge_field]
else:
ret = edge
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def __call__(self, edges):
return {self.out_field : edges.data[self.edge_field]}
def name(self):
return "copy_edge"
......
......@@ -8,7 +8,7 @@ __all__ = ["sum", "max"]
class ReduceFunction(object):
"""Base builtin reduce function class."""
def __call__(self, node, msgs):
def __call__(self, nodes):
"""Regular computation of this builtin.
This will be used when optimization is not available.
......@@ -35,14 +35,10 @@ class BundledReduceFunction(ReduceFunction):
return False
return True
def __call__(self, node, msgs):
ret = None
def __call__(self, nodes):
ret = dict()
for fn in self.fn_list:
rpr = fn(node, msgs)
if ret is None:
ret = rpr
else:
# ret and rpr must be dict
rpr = fn(nodes)
ret.update(rpr)
return ret
......@@ -60,8 +56,8 @@ class ReducerFunctionTemplate(ReduceFunction):
# NOTE: only sum is supported right now.
return self.name == "sum"
def __call__(self, node, msgs):
return {self.out_field : self.op(msgs[self.msg_field], 1)}
def __call__(self, nodes):
return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}
def name(self):
return self.name
......
This diff is collapsed.
......@@ -5,10 +5,11 @@ import numpy as np
from .base import ALL, DGLError
from . import backend as F
from collections import defaultdict as ddict
from .function import message as fmsg
from .function import reducer as fred
from .udf import NodeBatch, EdgeBatch
from . import utils
from collections import defaultdict as ddict
from ._ffi.function import _init_api
......@@ -176,7 +177,7 @@ class DegreeBucketingExecutor(Executor):
# loop over each bucket
# FIXME (lingfan): handle zero-degree case
for deg, vv, msg_id in zip(self.degrees, self.dsts, self.msg_ids):
dst_reprs = self.g.get_n_repr(vv)
v_data = self.g.get_n_repr(vv)
in_msgs = self.msg_frame.select_rows(msg_id)
def _reshape_fn(msg):
msg_shape = F.shape(msg)
......@@ -184,7 +185,8 @@ class DegreeBucketingExecutor(Executor):
return F.reshape(msg, new_shape)
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))
nb = NodeBatch(self.g, vv, v_data, reshaped_in_msgs)
new_reprs.append(self.rfunc(nb))
# Pack all reducer results together
keys = new_reprs[0].keys()
......@@ -320,7 +322,7 @@ class SendRecvExecutor(BasicExecutor):
@property
def edge_repr(self):
if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr(self.u, self.v)
self._edge_repr = self.g.get_e_repr((self.u, self.v))
return self._edge_repr
def _build_adjmat(self):
......@@ -432,8 +434,11 @@ def _create_send_and_recv_exec(graph, **kwargs):
dst = kwargs.pop('dst')
mfunc = kwargs.pop('message_func')
rfunc = kwargs.pop('reduce_func')
if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
if (isinstance(mfunc, fmsg.BundledMessageFunction)
or isinstance(rfunc, fred.BundledReduceFunction)):
if not isinstance(mfunc, fmsg.BundledMessageFunction):
mfunc = fmsg.BundledMessageFunction(mfunc)
if not isinstance(rfunc, fred.BundledReduceFunction):
rfunc = fred.BundledReduceFunction(rfunc)
exec_cls = BundledSendRecvExecutor
else:
......
"""User-defined function related data structures."""
from __future__ import absolute_import
from collections import Mapping
from .base import ALL, is_all
from . import backend as F
from . import utils
class EdgeBatch(object):
"""The object that represents a batch of edges.
Parameters
----------
g : DGLGraph
The graph object.
edges : tuple of utils.Index
The edge tuple (u, v, eid). eid can be ALL
src_data : dict of tensors
The src node features
edge_data : dict of tensors
The edge features.
dst_data : dict of tensors
The dst node features
"""
def __init__(self, g, edges, src_data, edge_data, dst_data):
self._g = g
self._edges = edges
self._src_data = src_data
self._edge_data = edge_data
self._dst_data = dst_data
@property
def src(self):
"""Return the feature data of the source nodes.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._src_data
@property
def dst(self):
"""Return the feature data of the destination nodes.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._dst_data
@property
def data(self):
"""Return the edge feature data.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._edge_data
def edges(self):
"""Return the edges contained in this batch.
Returns
-------
tuple of tensors
The edge tuple (u, v, eid).
"""
if is_all(self._edges[2]):
self._edges[2] = utils.toindex(F.arange(
0, self._g.number_of_edges(), dtype=F.int64))
u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
def batch_size(self):
"""Return the number of edges in this edge batch."""
return len(self._edges[0])
def __len__(self):
"""Return the number of edges in this edge batch."""
return self.batch_size()
class NodeBatch(object):
"""The object that represents a batch of nodes.
Parameters
----------
g : DGLGraph
The graph object.
nodes : utils.Index or ALL
The node ids.
data : dict of tensors
The node features
msgs : dict of tensors, optional
The messages.
"""
def __init__(self, g, nodes, data, msgs=None):
self._g = g
self._nodes = nodes
self._data = data
self._msgs = msgs
@property
def data(self):
"""Return the node feature data.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._data
@property
def mailbox(self):
"""Return the received messages.
If no messages received, a None will be returned.
Returns
-------
dict of str to tensors
The message data.
"""
return self._msgs
def nodes(self):
"""Return the nodes contained in this batch.
Returns
-------
tensor
The nodes.
"""
if is_all(self._nodes):
self._nodes = utils.toindex(F.arange(
0, self._g.number_of_nodes(), dtype=F.int64))
return self._nodes.tousertensor()
def batch_size(self):
"""Return the number of nodes in this node batch."""
if is_all(self._nodes):
return self._g.number_of_nodes()
else:
return len(self._nodes)
def __len__(self):
"""Return the number of nodes in this node batch."""
return self.batch_size()
......@@ -20,8 +20,14 @@ class Index(object):
def _dispatch(self, data):
"""Store data based on its type."""
if isinstance(data, Tensor):
if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1):
if not (F.dtype(data) == F.int64):
raise ValueError('Index data must be an int64 vector, but got: %s' % str(data))
if len(F.shape(data)) > 1:
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
if len(F.shape(data)) == 0:
# a tensor of one int
self._dispatch(int(data))
else:
self._user_tensor_data[F.get_context(data)] = data
elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1):
......@@ -343,3 +349,18 @@ def reorder(dict_like, index):
idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict
def parse_edges_tuple(edges):
"""Parse the given edges and return the tuple.
Parameters
----------
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
Returns
-------
A tuple of (u, v, eid)
"""
pass
"""Views of DGLGraph."""
from __future__ import absolute_import
from collections import MutableMapping, namedtuple
from .base import ALL, is_all, DGLError
from . import backend as F
from . import utils
NodeSpace = namedtuple('NodeSpace', ['data'])
class NodeView(object):
"""A NodeView class to act as G.nodes for a DGLGraph.
Compared with networkx's NodeView, DGL's NodeView is not
a map. DGL's NodeView supports creating data view from
a given list of nodes.
Parameters
----------
graph : DGLGraph
The graph.
Examples
--------
TBD
"""
__slot__ = '_graph'
def __init__(self, graph):
self._graph = graph
def __len__(self):
return self._graph.number_of_nodes()
def __getitem__(self, nodes):
if isinstance(nodes, slice):
# slice
if not (nodes.start is None and nodes.stop is None
and nodes.step is None):
raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=NodeDataView(self._graph, ALL))
else:
return NodeSpace(data=NodeDataView(self._graph, nodes))
def __call__(self):
"""Return the nodes."""
return F.arange(0, len(self))
class NodeDataView(MutableMapping):
__slot__ = ['_graph', '_nodes']
def __init__(self, graph, nodes):
self._graph = graph
self._nodes = nodes
def __getitem__(self, key):
return self._graph.get_n_repr(self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr({key : val}, self._nodes)
def __delitem__(self, key):
if not is_all(self._nodes):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
self._graph.pop_n_repr(key)
def __len__(self):
return len(self._graph._node_frame)
def __iter__(self):
return iter(self._graph._node_frame)
def __repr__(self):
data = self._graph.get_n_repr(self._nodes)
return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class EdgeView(object):
__slot__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __len__(self):
return self._graph.number_of_edges()
def __getitem__(self, edges):
if isinstance(edges, slice):
# slice
if not (edges.start is None and edges.stop is None
and edges.step is None):
raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=EdgeDataView(self._graph, ALL))
else:
return EdgeSpace(data=EdgeDataView(self._graph, edges))
def __call__(self, *args, **kwargs):
"""Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
class EdgeDataView(MutableMapping):
__slot__ = ['_graph', '_edges']
def __init__(self, graph, edges):
self._graph = graph
self._edges = edges
def __getitem__(self, key):
return self._graph.get_e_repr(self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr({key : val}, self._edges)
def __delitem__(self, key):
if not is_all(self._edges):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.edata[key]` instead.')
self._graph.pop_e_repr(key)
def __len__(self):
return len(self._graph._edge_frame)
def __iter__(self):
return iter(self._graph._edge_frame)
def __repr__(self):
data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame})
......@@ -11,20 +11,20 @@ def check_eq(a, b):
assert a.shape == b.shape
assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def message_func(src, edge):
assert len(src['h'].shape) == 2
assert src['h'].shape[1] == D
return {'m' : src['h']}
def message_func(edges):
assert len(edges.src['h'].shape) == 2
assert edges.src['h'].shape[1] == D
return {'m' : edges.src['h']}
def reduce_func(node, msgs):
msgs = msgs['m']
def reduce_func(nodes):
msgs = nodes.mailbox['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'m' : mx.nd.sum(msgs, 1)}
def apply_node_func(node):
return {'h' : node['h'] + node['m']}
def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['m']}
def generate_graph(grad=False):
g = DGLGraph()
......@@ -38,7 +38,7 @@ def generate_graph(grad=False):
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.set_n_repr({'h' : ncol})
g.ndata['h'] = ncol
return g
def test_batch_setter_getter():
......@@ -47,15 +47,15 @@ def test_batch_setter_getter():
g = generate_graph()
# set all nodes
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
assert _pfc(g.ndata['h']) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0
assert len(g.ndata) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes
u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64')
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
......@@ -81,77 +81,77 @@ def test_batch_setter_getter():
9, 0, 16
'''
# set all edges
g.set_e_repr({'l' : mx.nd.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
g.edata['l'] = mx.nd.zeros((17, D))
assert _pfc(g.edata['l']) == [0.] * 17
# pop edges
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr({'l' : mx.nd.zeros((17, D))})
assert len(g.edata) == 0
g.edata['l'] = mx.nd.zeros((17, D))
# set partial edges (many-many)
u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((5, D))}, u, v)
g.edges[u, v].data['l'] = mx.nd.ones((5, D))
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (many-one)
u = mx.nd.array([3, 4, 6], dtype='int64')
v = mx.nd.array([9], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = mx.nd.ones((3, D))
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (one-many)
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([4, 5, 6], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = mx.nd.ones((3, D))
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# get partial edges (many-many)
u = mx.nd.array([0, 6, 0], dtype='int64')
v = mx.nd.array([6, 9, 7], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = mx.nd.array([5, 6, 7], dtype='int64')
v = mx.nd.array([9], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([3, 4, 5], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
with mx.autograd.record():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
h1 = g.ndata['h']
h1.attach_grad()
# partial set
v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D))
hh.attach_grad()
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
h2 = g.ndata['h']
h2.backward(mx.nd.ones((10, D)) * 2)
check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
def _fmsg(edges):
assert edges.src['h'].shape == (5, D)
return {'m' : edges.src['h']}
g.register_message_func(_fmsg)
# many-many send
u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v)
g.send((u, v))
# one-many send
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v)
g.send((u, v))
# many-one send
u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
v = mx.nd.array([9], dtype='int64')
g.send(u, v)
g.send((u, v))
def test_batch_recv():
# basic recv test
......@@ -162,7 +162,7 @@ def test_batch_recv():
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
reduce_msg_shapes.clear()
g.send(u, v)
g.send((u, v))
#g.recv(th.unique(v))
#assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
#reduce_msg_shapes.clear()
......@@ -177,7 +177,7 @@ def test_update_routines():
reduce_msg_shapes.clear()
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
g.send_and_recv(u, v)
g.send_and_recv((u, v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
......@@ -208,14 +208,14 @@ def test_reduce_0deg():
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : node['h'] + msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = mx.nd.random.normal(shape=(5, 5))
g.set_n_repr({'h': old_repr})
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
......@@ -224,25 +224,25 @@ def test_pull_0deg():
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
......
......@@ -10,20 +10,20 @@ def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(src, edge):
assert len(src['h'].shape) == 2
assert src['h'].shape[1] == D
return {'m' : src['h']}
def message_func(edges):
assert len(edges.src['h'].shape) == 2
assert edges.src['h'].shape[1] == D
return {'m' : edges.src['h']}
def reduce_func(node, msgs):
msgs = msgs['m']
def reduce_func(nodes):
msgs = nodes.mailbox['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node):
return {'h' : node['h'] + node['accum']}
def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['accum']}
def generate_graph(grad=False):
g = DGLGraph()
......@@ -36,10 +36,11 @@ def generate_graph(grad=False):
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad)
accumcol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol})
g.ndata['h'] = ncol
g.edata['w'] = ecol
g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
return g
def test_batch_setter_getter():
......@@ -47,20 +48,20 @@ def test_batch_setter_getter():
return list(x.numpy()[:,0])
g = generate_graph()
# set all nodes
g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
g.ndata['h'] = th.zeros((10, D))
assert th.allclose(g.ndata['h'], th.zeros((10, D)))
# pop nodes
old_len = len(g.get_n_repr())
old_len = len(g.ndata)
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))})
assert len(g.ndata) == old_len - 1
g.ndata['h'] = th.zeros((10, D))
# set partial nodes
u = th.tensor([1, 3, 5])
g.set_n_repr({'h' : th.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
g.nodes[u].data['h'] = th.ones((3, D))
assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = th.tensor([1, 2, 3])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]
'''
s, d, eid
......@@ -83,75 +84,75 @@ def test_batch_setter_getter():
9, 0, 16
'''
# set all edges
g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
g.edata['l'] = th.zeros((17, D))
assert _pfc(g.edata['l']) == [0.] * 17
# pop edges
old_len = len(g.get_e_repr())
old_len = len(g.edata)
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))})
assert len(g.edata) == old_len - 1
g.edata['l'] = th.zeros((17, D))
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr({'l' : th.ones((5, D))}, u, v)
g.edges[u, v].data['l'] = th.ones((5, D))
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (many-one)
u = th.tensor([3, 4, 6])
v = th.tensor([9])
g.set_e_repr({'l' : th.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = th.ones((3, D))
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (one-many)
u = th.tensor([0])
v = th.tensor([4, 5, 6])
g.set_e_repr({'l' : th.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = th.ones((3, D))
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# get partial edges (many-many)
u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = th.tensor([5, 6, 7])
v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = th.tensor([0])
v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
h1 = g.ndata['h']
# partial set
v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
g.nodes[v].data['h'] = hh
h2 = g.ndata['h']
h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
def _fmsg(edges):
assert edges.src['h'].shape == (5, D)
return {'m' : edges.src['h']}
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
g.send((u, v))
# one-many send
u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
g.send((u, v))
# many-one send
u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9])
g.send(u, v)
g.send((u, v))
def test_batch_recv():
# basic recv test
......@@ -162,11 +163,25 @@ def test_batch_recv():
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.send(u, v)
g.send((u, v))
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_update_edges():
def _upd(edges):
return {'w' : edges.data['w'] * 2}
g = generate_graph()
g.register_edge_func(_upd)
old = g.edata['w']
g.update_edges()
assert th.allclose(old * 2, g.edata['w'])
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_edges((u, v), lambda edges : {'w' : edges.data['w'] * 0.})
eid = g.edge_ids(u, v)
assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func)
......@@ -177,7 +192,7 @@ def test_update_routines():
reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.send_and_recv(u, v)
g.send_and_recv((u, v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
......@@ -208,14 +223,14 @@ def test_reduce_0deg():
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : node['h'] + msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = th.randn(5, 5)
g.set_n_repr({'h' : old_repr})
g.ndata['h'] = old_repr
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0))
......@@ -224,26 +239,26 @@ def test_pull_0deg():
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = th.randn(2, 5)
g.set_n_repr({'h' : old_repr})
g.ndata['h'] = old_repr
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.set_n_repr({'h' : old_repr})
g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
......@@ -253,27 +268,26 @@ def _disabled_test_send_twice():
g.add_nodes(3)
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(src, edge):
return {'a': src['a']}
def _message_b(src, edge):
return {'a': src['a'] * 3}
def _reduce(node, msgs):
assert msgs is not None
return {'a': msgs['a'].max(1)[0]}
def _message_a(edges):
return {'a': edges.src['a']}
def _message_b(edges):
return {'a': edges.src['a'] * 3}
def _reduce(nodes):
return {'a': nodes.mailbox['a'].max(1)[0]}
old_repr = th.randn(3, 5)
g.set_n_repr({'a': old_repr})
g.send(0, 1, _message_a)
g.send(0, 1, _message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((0, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr[0] * 3)
g.set_n_repr({'a': old_repr})
g.send(0, 1, _message_a)
g.send(2, 1, _message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((2, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_multigraph():
......@@ -284,64 +298,63 @@ def test_send_multigraph():
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(src, edge):
return {'a': edge['a']}
def _message_b(src, edge):
return {'a': edge['a'] * 3}
def _reduce(node, msgs):
assert msgs is not None
return {'a': msgs['a'].max(1)[0]}
def _message_a(edges):
return {'a': edges.data['a']}
def _message_b(edges):
return {'a': edges.data['a'] * 3}
def _reduce(nodes):
return {'a': nodes.mailbox['a'].max(1)[0]}
def answer(*args):
return th.stack(args, 0).max(0)[0]
# send by eid
old_repr = th.randn(4, 5)
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=[0, 2], message_func=_message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send([0, 2], message_func=_message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=[0, 2, 3], message_func=_message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send([0, 2, 3], message_func=_message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
# send on multigraph
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send([0, 2], [1, 1], _message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send(([0, 2], [1, 1]), _message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr.max(0)[0])
# consecutive send and send_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(2, 1, _message_a)
g.send(eid=[0, 1], message_func=_message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send((2, 1), _message_a)
g.send([0, 1], message_func=_message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
# consecutive send_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=0, message_func=_message_a)
g.send(eid=1, message_func=_message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send(0, message_func=_message_a)
g.send(1, message_func=_message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
# send_and_recv_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
......@@ -353,29 +366,25 @@ def test_dynamic_addition():
# Test node addition
g.add_nodes(N)
g.set_n_repr({'h1': th.randn(N, D),
g.ndata.update({'h1': th.randn(N, D),
'h2': th.randn(N, D)})
g.add_nodes(3)
n_repr = g.get_n_repr()
assert n_repr['h1'].shape[0] == n_repr['h2'].shape[0] == N + 3
assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3
# Test edge addition
g.add_edge(0, 1)
g.add_edge(1, 0)
g.set_e_repr({'h1': th.randn(2, D),
g.edata.update({'h1': th.randn(2, D),
'h2': th.randn(2, D)})
e_repr = g.get_e_repr()
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 2
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2
g.add_edges([0, 2], [2, 0])
e_repr = g.get_e_repr()
g.set_e_repr({'h1': th.randn(4, D)})
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 4
g.edata['h1'] = th.randn(4, D)
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4
g.add_edge(1, 2)
g.set_e_repr_by_id({'h1': th.randn(1, D)}, eid=4)
e_repr = g.get_e_repr()
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 5
g.edges[4].data['h1'] = th.randn(1, D)
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
if __name__ == '__main__':
......@@ -383,6 +392,7 @@ if __name__ == '__main__':
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_edges()
test_update_routines()
test_reduce_0deg()
test_pull_0deg()
......
......@@ -18,8 +18,8 @@ def tree1():
g.add_edge(4, 1)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr({'h' : th.randn(4, 10)})
g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
g.edata['h'] = th.randn(4, 10)
return g
def tree2():
......@@ -37,17 +37,17 @@ def tree2():
g.add_edge(0, 4)
g.add_edge(4, 1)
g.add_edge(3, 1)
g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr({'h' : th.randn(4, 10)})
g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
g.edata['h'] = th.randn(4, 10)
return g
def test_batch_unbatch():
t1 = tree1()
t2 = tree2()
n1 = t1.get_n_repr()['h']
n2 = t2.get_n_repr()['h']
e1 = t1.get_e_repr()['h']
e2 = t2.get_e_repr()['h']
n1 = t1.ndata['h']
n2 = t2.ndata['h']
e1 = t1.edata['h']
e2 = t2.edata['h']
bg = dgl.batch([t1, t2])
assert bg.number_of_nodes() == 10
......@@ -57,10 +57,10 @@ def test_batch_unbatch():
assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])
assert th.allclose(t1.ndata['h'], tt1.ndata['h'])
assert th.allclose(t1.edata['h'], tt1.edata['h'])
assert th.allclose(t2.ndata['h'], tt2.ndata['h'])
assert th.allclose(t2.edata['h'], tt2.edata['h'])
def test_batch_unbatch1():
t1 = tree1()
......@@ -74,29 +74,29 @@ def test_batch_unbatch1():
assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])
assert th.allclose(t2.ndata['h'], s1.ndata['h'])
assert th.allclose(t2.edata['h'], s1.edata['h'])
assert th.allclose(t1.ndata['h'], s2.ndata['h'])
assert th.allclose(t1.edata['h'], s2.edata['h'])
assert th.allclose(t2.ndata['h'], s3.ndata['h'])
assert th.allclose(t2.edata['h'], s3.edata['h'])
def test_batch_sendrecv():
t1 = tree1()
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
bg.register_message_func(lambda edges: {'m' : edges.src['h']})
bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
u = [3, 4, 2 + 5, 0 + 5]
v = [1, 1, 4 + 5, 4 + 5]
bg.send(u, v)
bg.send((u, v))
bg.recv(v)
t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()['h'][1] == 7
assert t2.get_n_repr()['h'][4] == 2
assert t1.ndata['h'][1] == 7
assert t2.ndata['h'][4] == 2
def test_batch_propagate():
......@@ -104,8 +104,8 @@ def test_batch_propagate():
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
bg.register_message_func(lambda edges: {'m' : edges.src['h']})
bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
# get leaves.
order = []
......@@ -123,23 +123,23 @@ def test_batch_propagate():
bg.propagate(traverser=order)
t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()['h'][0] == 9
assert t2.get_n_repr()['h'][1] == 5
assert t1.ndata['h'][0] == 9
assert t2.ndata['h'][1] == 5
def test_batched_edge_ordering():
g1 = dgl.DGLGraph()
g1.add_nodes(6)
g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
e1 = th.randn(5, 10)
g1.set_e_repr({'h' : e1})
g1.edata['h'] = e1
g2 = dgl.DGLGraph()
g2.add_nodes(6)
g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
e2 = th.randn(6, 10)
g2.set_e_repr({'h' : e2})
g2.edata['h'] = e2
g = dgl.batch([g1, g2])
r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
r1 = g.edata['h'][g.edge_id(4, 5)]
r2 = g1.edata['h'][g1.edge_id(4, 5)]
assert th.equal(r1, r2)
def test_batch_no_edge():
......
......@@ -12,8 +12,8 @@ def test_filter():
n_repr[[1, 3]] = 1
e_repr[[1, 3]] = 1
g.set_n_repr({'a': n_repr})
g.set_e_repr({'a': e_repr})
g.ndata['a'] = n_repr
g.edata['a'] = e_repr
def predicate(r):
return r['a'].max(1)[0] > 0
......
......@@ -6,7 +6,7 @@ def generate_graph():
g = dgl.DGLGraph()
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11, dtype=th.float)
g.set_n_repr({'h': h})
g.ndata['h'] = h
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
......@@ -15,29 +15,11 @@ def generate_graph():
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr({'h' : h})
g.edata['h'] = h
return g
def generate_graph1():
"""graph with anonymous repr"""
g = dgl.DGLGraph()
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11, dtype=th.float)
h = th.arange(1, 11, dtype=th.float)
g.set_n_repr(h)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr(h)
return g
def reducer_both(node, msgs):
return {'h' : th.sum(msgs['m'], 1)}
def reducer_both(nodes):
return {'h' : th.sum(nodes.mailbox['m'], 1)}
def test_copy_src():
# copy_src with both fields
......@@ -45,7 +27,7 @@ def test_copy_src():
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
assert th.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge():
......@@ -54,7 +36,7 @@ def test_copy_edge():
g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
assert th.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge():
......@@ -63,7 +45,7 @@ def test_src_mul_edge():
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
assert th.allclose(g.ndata['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment