"examples/vscode:/vscode.git/clone" did not exist on "efae0f97d8b18867398981d226423472a3069844"
Unverified commit 68ec6247, authored by Minjie Wang and committed by GitHub

[API][Doc] API change & basic tutorials (#113)

* Add SH tutorials

* setup sphinx-gallery; work on graph tutorial

* draft dglgraph tutorial

* update readme to include document url

* rm obsolete file

* Draft the message passing tutorial

* Capsule code (#102)

* add capsule example

* clean code

* better naming

* better naming

* [GCN]tutorial scaffold

* fix capsule example code

* remove previous capsule example code

* graph structure edit

* modified:   2_graph.py

* update doc of capsule

* update capsule docs

* update capsule docs

* add msg passing primer

* GCN-GAT tutorial Section 1 and 2

* comment for API improvement

* section 3

* Tutorial API change (#115)

* change the API as discussed; add toy example

* enable the new set/get syntax

* fixed pytorch utest

* fixed gcn example

* fixed gat example

* fixed mx utests

* fix mx utest

* delete apply edges; add utest for update_edges

* small change on toy example

* fix utest

* fix out/in degrees bug

* update pagerank example and add it to CI

* add delitem for dataview

* make edges() return form that is compatible with send/update_edges etc

* fix index bug when the given data is one-int-tensor

* fix doc
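
The heart of the API change (#115): message, reduce, and apply UDFs now take batched EdgeBatch/NodeBatch arguments instead of raw feature dicts. A minimal before/after sketch, mirroring the GCN example updated below (illustrative only, not part of the commit):

# Before: UDFs received per-type feature dicts.
def gcn_msg(src, edge):
    return {'m' : src['h']}

def gcn_reduce(node, msgs):
    return {'h' : torch.sum(msgs['m'], 1)}

# After: UDFs receive EdgeBatch / NodeBatch objects.
def gcn_msg(edges):
    return {'m' : edges.src['h']}

def gcn_reduce(nodes):
    return {'h' : torch.sum(nodes.mailbox['m'], 1)}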
parent 2ecd2b23
@@ -17,7 +17,7 @@ def build_dgl() {
     }
     dir ('build') {
         sh 'cmake ..'
-        sh 'make -j$(nproc)'
+        sh 'make -j4'
     }
 }
......
@@ -46,6 +46,7 @@ extensions = [
     'sphinx.ext.napoleon',
     'sphinx.ext.viewcode',
     'sphinx.ext.intersphinx',
+    'sphinx.ext.graphviz',
     'sphinx_gallery.gen_gallery',
 ]
......
from __future__ import division
import networkx as nx
from dgl.graph import DGLGraph

DAMP = 0.85
N = 100
K = 10

def message_func(src, dst, edge):
    return src['pv'] / src['deg']

def update_func(node, accum):
    pv = (1 - DAMP) / N + DAMP * accum
    return {'pv' : pv}

def compute_pagerank(g):
    g = DGLGraph(g)
    print(g.number_of_edges(), g.number_of_nodes())
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
    # init pv value
    for n in g.nodes():
        g.node[n]['pv'] = 1 / N
        g.node[n]['deg'] = g.out_degree(n)
    # pagerank
    for k in range(K):
        g.update_all()
    return [g.node[n]['pv'] for n in g.nodes()]

if __name__ == '__main__':
    g = nx.erdos_renyi_graph(N, 0.05)
    pv = compute_pagerank(g)
    print(pv)
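
For reference, the iteration implemented by message_func and update_func above is the standard damped PageRank update:

# PageRank update computed above, with damping factor d = DAMP over N nodes:
#   pv[v] = (1 - d) / N + d * sum(pv[u] / out_deg(u) for u in in_neighbors(v))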
@@ -16,18 +16,18 @@ import dgl
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data

-def gat_message(src, edge):
-    return {'ft' : src['ft'], 'a2' : src['a2']}
+def gat_message(edges):
+    return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}

 class GATReduce(nn.Module):
     def __init__(self, attn_drop):
         super(GATReduce, self).__init__()
         self.attn_drop = attn_drop

-    def forward(self, node, msgs):
-        a1 = torch.unsqueeze(node['a1'], 1)  # shape (B, 1, 1)
-        a2 = msgs['a2']  # shape (B, deg, 1)
-        ft = msgs['ft']  # shape (B, deg, D)
+    def forward(self, nodes):
+        a1 = torch.unsqueeze(nodes.data['a1'], 1)  # shape (B, 1, 1)
+        a2 = nodes.mailbox['a2']  # shape (B, deg, 1)
+        ft = nodes.mailbox['ft']  # shape (B, deg, D)
         # attention
         a = a1 + a2  # shape (B, deg, 1)
         e = F.softmax(F.leaky_relu(a), dim=1)
@@ -46,13 +46,13 @@ class GATFinalize(nn.Module):
         if indim != hiddendim:
             self.residual_fc = nn.Linear(indim, hiddendim)

-    def forward(self, node):
-        ret = node['accum']
+    def forward(self, nodes):
+        ret = nodes.data['accum']
         if self.residual:
             if self.residual_fc is not None:
-                ret = self.residual_fc(node['h']) + ret
+                ret = self.residual_fc(nodes.data['h']) + ret
             else:
-                ret = node['h'] + ret
+                ret = nodes.data['h'] + ret
         return {'head%d' % self.headid : self.activation(ret)}

 class GATPrepare(nn.Module):
@@ -120,7 +120,7 @@ class GAT(nn.Module):
         for hid in range(self.num_heads):
             i = l * self.num_heads + hid
             # prepare
-            self.g.set_n_repr(self.prp[i](last))
+            self.g.ndata.update(self.prp[i](last))
             # message passing
             self.g.update_all(gat_message, self.red[i], self.fnl[i])
             # merge all the heads
@@ -128,7 +128,7 @@ class GAT(nn.Module):
             [self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
             dim=1)
         # output projection
-        self.g.set_n_repr(self.prp[-1](last))
+        self.g.ndata.update(self.prp[-1](last))
         self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
         return self.g.pop_n_repr('head0')
......
@@ -14,24 +14,22 @@ import torch.nn.functional as F
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data

-def gcn_msg(src, edge):
-    return {'m' : src['h']}
+def gcn_msg(edges):
+    return {'m' : edges.src['h']}

-def gcn_reduce(node, msgs):
-    return {'h' : torch.sum(msgs['m'], 1)}
+def gcn_reduce(nodes):
+    return {'h' : torch.sum(nodes.mailbox['m'], 1)}

 class NodeApplyModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
         super(NodeApplyModule, self).__init__()
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation

-    def forward(self, node):
-        h = self.linear(node['h'])
+    def forward(self, nodes):
+        h = self.linear(nodes.data['h'])
         if self.activation:
             h = self.activation(h)
         return {'h' : h}

 class GCN(nn.Module):
@@ -62,13 +60,13 @@ class GCN(nn.Module):
         self.layers.append(NodeApplyModule(n_hidden, n_classes))

     def forward(self, features):
-        self.g.set_n_repr({'h' : features})
+        self.g.ndata['h'] = features
         for layer in self.layers:
             # apply dropout
             if self.dropout:
                 self.g.apply_nodes(apply_node_func=
-                    lambda node: {'h': self.dropout(node['h'])})
+                    lambda nodes: {'h': self.dropout(nodes.data['h'])})
             self.g.update_all(gcn_msg, gcn_reduce, layer)
         return self.g.pop_n_repr('h')
......
@@ -22,8 +22,8 @@ class NodeApplyModule(nn.Module):
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation

-    def forward(self, node):
-        h = self.linear(node['h'])
+    def forward(self, nodes):
+        h = self.linear(nodes.data['h'])
         if self.activation:
             h = self.activation(h)
@@ -57,13 +57,13 @@ class GCN(nn.Module):
         self.layers.append(NodeApplyModule(n_hidden, n_classes))

     def forward(self, features):
-        self.g.set_n_repr({'h' : features})
+        self.g.ndata['h'] = features
         for layer in self.layers:
             # apply dropout
             if self.dropout:
                 self.g.apply_nodes(apply_node_func=
-                    lambda node: {'h': self.dropout(node['h'])})
+                    lambda nodes: {'h': self.dropout(nodes.data['h'])})
             self.g.update_all(fn.copy_src(src='h', out='m'),
                               fn.sum(msg='m', out='h'),
                               layer)
......
import networkx as nx
import torch
import dgl
import dgl.function as fn

N = 100
g = nx.erdos_renyi_graph(N, 0.05)
g = dgl.DGLGraph(g)

DAMP = 0.85
K = 10

def compute_pagerank(g):
    g.ndata['pv'] = torch.ones(N) / N
    degrees = g.out_degrees(g.nodes()).type(torch.float32)
    for k in range(K):
        g.ndata['pv'] = g.ndata['pv'] / degrees
        g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                     reduce_func=fn.sum(msg='m', out='pv'))
        g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
    return g.ndata['pv']

pv = compute_pagerank(g)
print(pv)
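
The builtin fn.copy_src / fn.sum pair used above is interchangeable with explicit UDFs under the new batched API; a sketch of the equivalent user-defined step (illustrative only):

# Equivalent UDF formulation of the update_all step above (illustrative).
def pagerank_message(edges):
    # forward each source node's degree-normalized pv value
    return {'m' : edges.src['pv']}

def pagerank_reduce(nodes):
    # sum the incoming messages along the neighbor dimension
    return {'pv' : torch.sum(nodes.mailbox['m'], 1)}

# g.update_all(pagerank_message, pagerank_reduce)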
@@ -11,3 +11,4 @@ from .base import ALL
 from .batched_graph import *
 from .graph import DGLGraph
 from .subgraph import DGLSubGraph
+from .udf import NodeBatch, EdgeBatch
@@ -10,7 +10,7 @@ __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
 class MessageFunction(object):
     """Base builtin message function class."""

-    def __call__(self, src, edge):
+    def __call__(self, edges):
         """Regular computation of this builtin.

         This will be used when optimization is not available.
@@ -38,14 +38,10 @@ class BundledMessageFunction(MessageFunction):
                 return False
         return True

-    def __call__(self, src, edge):
-        ret = None
+    def __call__(self, edges):
+        ret = dict()
         for fn in self.fn_list:
-            msg = fn(src, edge)
-            if ret is None:
-                ret = msg
-            else:
-                # ret and msg must be dict
-                ret.update(msg)
+            msg = fn(edges)
+            ret.update(msg)
         return ret
@@ -83,8 +79,9 @@ class SrcMulEdgeMessageFunction(MessageFunction):
         return _is_spmv_supported_node_feat(g, self.src_field) \
             and _is_spmv_supported_edge_feat(g, self.edge_field)

-    def __call__(self, src, edge):
-        ret = self.mul_op(src[self.src_field], edge[self.edge_field])
+    def __call__(self, edges):
+        ret = self.mul_op(edges.src[self.src_field],
+                          edges.data[self.edge_field])
         return {self.out_field : ret}

     def name(self):
@@ -98,8 +95,8 @@ class CopySrcMessageFunction(MessageFunction):
     def is_spmv_supported(self, g):
         return _is_spmv_supported_node_feat(g, self.src_field)

-    def __call__(self, src, edge):
-        return {self.out_field : src[self.src_field]}
+    def __call__(self, edges):
+        return {self.out_field : edges.src[self.src_field]}

     def name(self):
         return "copy_src"
@@ -114,15 +111,8 @@ class CopyEdgeMessageFunction(MessageFunction):
         return False
         # return _is_spmv_supported_edge_feat(g, self.edge_field)

-    def __call__(self, src, edge):
-        if self.edge_field is not None:
-            ret = edge[self.edge_field]
-        else:
-            ret = edge
-        if self.out_field is None:
-            return ret
-        else:
-            return {self.out_field : ret}
+    def __call__(self, edges):
+        return {self.out_field : edges.data[self.edge_field]}

     def name(self):
         return "copy_edge"
......
@@ -8,7 +8,7 @@ __all__ = ["sum", "max"]
 class ReduceFunction(object):
     """Base builtin reduce function class."""

-    def __call__(self, node, msgs):
+    def __call__(self, nodes):
         """Regular computation of this builtin.

         This will be used when optimization is not available.
@@ -35,14 +35,10 @@ class BundledReduceFunction(ReduceFunction):
                 return False
         return True

-    def __call__(self, node, msgs):
-        ret = None
+    def __call__(self, nodes):
+        ret = dict()
         for fn in self.fn_list:
-            rpr = fn(node, msgs)
-            if ret is None:
-                ret = rpr
-            else:
-                # ret and rpr must be dict
-                ret.update(rpr)
+            rpr = fn(nodes)
+            ret.update(rpr)
         return ret
@@ -60,8 +56,8 @@ class ReducerFunctionTemplate(ReduceFunction):
         # NOTE: only sum is supported right now.
         return self.name == "sum"

-    def __call__(self, node, msgs):
-        return {self.out_field : self.op(msgs[self.msg_field], 1)}
+    def __call__(self, nodes):
+        return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}

     def name(self):
         return self.name
......
@@ -5,10 +5,11 @@ import numpy as np
 from .base import ALL, DGLError
 from . import backend as F

-from collections import defaultdict as ddict
 from .function import message as fmsg
 from .function import reducer as fred
+from .udf import NodeBatch, EdgeBatch
 from . import utils
+from collections import defaultdict as ddict

 from ._ffi.function import _init_api
@@ -176,7 +177,7 @@ class DegreeBucketingExecutor(Executor):
         # loop over each bucket
         # FIXME (lingfan): handle zero-degree case
         for deg, vv, msg_id in zip(self.degrees, self.dsts, self.msg_ids):
-            dst_reprs = self.g.get_n_repr(vv)
+            v_data = self.g.get_n_repr(vv)
             in_msgs = self.msg_frame.select_rows(msg_id)
             def _reshape_fn(msg):
                 msg_shape = F.shape(msg)
@@ -184,7 +185,8 @@ class DegreeBucketingExecutor(Executor):
                 return F.reshape(msg, new_shape)
             reshaped_in_msgs = utils.LazyDict(
                 lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
-            new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))
+            nb = NodeBatch(self.g, vv, v_data, reshaped_in_msgs)
+            new_reprs.append(self.rfunc(nb))

         # Pack all reducer results together
         keys = new_reprs[0].keys()
@@ -320,7 +322,7 @@ class SendRecvExecutor(BasicExecutor):
     @property
     def edge_repr(self):
         if self._edge_repr is None:
-            self._edge_repr = self.g.get_e_repr(self.u, self.v)
+            self._edge_repr = self.g.get_e_repr((self.u, self.v))
         return self._edge_repr

     def _build_adjmat(self):
@@ -432,8 +434,11 @@ def _create_send_and_recv_exec(graph, **kwargs):
     dst = kwargs.pop('dst')
     mfunc = kwargs.pop('message_func')
     rfunc = kwargs.pop('reduce_func')
-    if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
-        mfunc = fmsg.BundledMessageFunction(mfunc)
-        rfunc = fred.BundledReduceFunction(rfunc)
+    if (isinstance(mfunc, fmsg.BundledMessageFunction)
+            or isinstance(rfunc, fred.BundledReduceFunction)):
+        if not isinstance(mfunc, fmsg.BundledMessageFunction):
+            mfunc = fmsg.BundledMessageFunction(mfunc)
+        if not isinstance(rfunc, fred.BundledReduceFunction):
+            rfunc = fred.BundledReduceFunction(rfunc)
         exec_cls = BundledSendRecvExecutor
     else:
......
"""User-defined function related data structures."""
from __future__ import absolute_import
from collections import Mapping
from .base import ALL, is_all
from . import backend as F
from . import utils
class EdgeBatch(object):
"""The object that represents a batch of edges.
Parameters
----------
g : DGLGraph
The graph object.
edges : tuple of utils.Index
The edge tuple (u, v, eid). eid can be ALL
src_data : dict of tensors
The src node features
edge_data : dict of tensors
The edge features.
dst_data : dict of tensors
The dst node features
"""
def __init__(self, g, edges, src_data, edge_data, dst_data):
self._g = g
self._edges = edges
self._src_data = src_data
self._edge_data = edge_data
self._dst_data = dst_data
@property
def src(self):
"""Return the feature data of the source nodes.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._src_data
@property
def dst(self):
"""Return the feature data of the destination nodes.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._dst_data
@property
def data(self):
"""Return the edge feature data.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._edge_data
def edges(self):
"""Return the edges contained in this batch.
Returns
-------
tuple of tensors
The edge tuple (u, v, eid).
"""
if is_all(self._edges[2]):
self._edges[2] = utils.toindex(F.arange(
0, self._g.number_of_edges(), dtype=F.int64))
u, v, eid = self._edges
return (u.tousertensor(), v.tousertensor(), eid.tousertensor())
def batch_size(self):
"""Return the number of edges in this edge batch."""
return len(self._edges[0])
def __len__(self):
"""Return the number of edges in this edge batch."""
return self.batch_size()
class NodeBatch(object):
"""The object that represents a batch of nodes.
Parameters
----------
g : DGLGraph
The graph object.
nodes : utils.Index or ALL
The node ids.
data : dict of tensors
The node features
msgs : dict of tensors, optional
The messages.
"""
def __init__(self, g, nodes, data, msgs=None):
self._g = g
self._nodes = nodes
self._data = data
self._msgs = msgs
@property
def data(self):
"""Return the node feature data.
Returns
-------
dict of str to tensors
The feature data.
"""
return self._data
@property
def mailbox(self):
"""Return the received messages.
If no messages received, a None will be returned.
Returns
-------
dict of str to tensors
The message data.
"""
return self._msgs
def nodes(self):
"""Return the nodes contained in this batch.
Returns
-------
tensor
The nodes.
"""
if is_all(self._nodes):
self._nodes = utils.toindex(F.arange(
0, self._g.number_of_nodes(), dtype=F.int64))
return self._nodes.tousertensor()
def batch_size(self):
"""Return the number of nodes in this node batch."""
if is_all(self._nodes):
return self._g.number_of_nodes()
else:
return len(self._nodes)
def __len__(self):
"""Return the number of nodes in this node batch."""
return self.batch_size()
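
Message, reduce, and apply functions are written against these two classes; a minimal sketch of each kind of UDF (feature names 'h' and 'm' are illustrative):

# Minimal sketch of UDFs written against EdgeBatch / NodeBatch.
def send_h(edges):            # message UDF: edges is an EdgeBatch
    return {'m' : edges.src['h']}

def sum_h(nodes):             # reduce UDF: nodes is a NodeBatch
    return {'h' : nodes.mailbox['m'].sum(1)}

def halve_h(nodes):           # apply UDF: usable with g.apply_nodes
    return {'h' : nodes.data['h'] * 0.5}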
@@ -20,8 +20,14 @@ class Index(object):
     def _dispatch(self, data):
         """Store data based on its type."""
         if isinstance(data, Tensor):
-            if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1):
-                raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
-            self._user_tensor_data[F.get_context(data)] = data
+            if not (F.dtype(data) == F.int64):
+                raise ValueError('Index data must be an int64 vector, but got: %s' % str(data))
+            if len(F.shape(data)) > 1:
+                raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
+            if len(F.shape(data)) == 0:
+                # a tensor of one int
+                self._dispatch(int(data))
+            else:
+                self._user_tensor_data[F.get_context(data)] = data
         elif isinstance(data, nd.NDArray):
             if not (data.dtype == 'int64' and len(data.shape) == 1):
@@ -343,3 +349,18 @@ def reorder(dict_like, index):
         idx_ctx = index.tousertensor(F.get_context(val))
         new_dict[key] = F.gather_row(val, idx_ctx)
     return new_dict
+
+def parse_edges_tuple(edges):
+    """Parse the given edges and return the tuple.
+
+    Parameters
+    ----------
+    edges : edges
+        Edges can be a pair of endpoint nodes (u, v), or a
+        tensor of edge ids. The default value is all the edges.
+
+    Returns
+    -------
+    A tuple of (u, v, eid)
+    """
+    pass
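
parse_edges_tuple is left as a stub in this commit; a purely hypothetical sketch of the shape its docstring suggests:

# Hypothetical sketch only -- the commit leaves parse_edges_tuple unimplemented.
# Following the docstring, it would normalize `edges`, given either as a
# (u, v) pair or as a tensor of edge ids, into a uniform (u, v, eid) tuple:
#
#     if isinstance(edges, tuple):
#         u, v = edges
#         return utils.toindex(u), utils.toindex(v), None  # eids resolved later
#     return None, None, utils.toindex(edges)              # edge ids given directly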
"""Views of DGLGraph."""
from __future__ import absolute_import
from collections import MutableMapping, namedtuple
from .base import ALL, is_all, DGLError
from . import backend as F
from . import utils
NodeSpace = namedtuple('NodeSpace', ['data'])
class NodeView(object):
"""A NodeView class to act as G.nodes for a DGLGraph.
Compared with networkx's NodeView, DGL's NodeView is not
a map. DGL's NodeView supports creating data view from
a given list of nodes.
Parameters
----------
graph : DGLGraph
The graph.
Examples
--------
TBD
"""
__slot__ = '_graph'
def __init__(self, graph):
self._graph = graph
def __len__(self):
return self._graph.number_of_nodes()
def __getitem__(self, nodes):
if isinstance(nodes, slice):
# slice
if not (nodes.start is None and nodes.stop is None
and nodes.step is None):
raise DGLError('Currently only full slice ":" is supported')
return NodeSpace(data=NodeDataView(self._graph, ALL))
else:
return NodeSpace(data=NodeDataView(self._graph, nodes))
def __call__(self):
"""Return the nodes."""
return F.arange(0, len(self))
class NodeDataView(MutableMapping):
__slot__ = ['_graph', '_nodes']
def __init__(self, graph, nodes):
self._graph = graph
self._nodes = nodes
def __getitem__(self, key):
return self._graph.get_n_repr(self._nodes)[key]
def __setitem__(self, key, val):
self._graph.set_n_repr({key : val}, self._nodes)
def __delitem__(self, key):
if not is_all(self._nodes):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.ndata[key]` instead.')
self._graph.pop_n_repr(key)
def __len__(self):
return len(self._graph._node_frame)
def __iter__(self):
return iter(self._graph._node_frame)
def __repr__(self):
data = self._graph.get_n_repr(self._nodes)
return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])
class EdgeView(object):
__slot__ = ['_graph']
def __init__(self, graph):
self._graph = graph
def __len__(self):
return self._graph.number_of_edges()
def __getitem__(self, edges):
if isinstance(edges, slice):
# slice
if not (edges.start is None and edges.stop is None
and edges.step is None):
raise DGLError('Currently only full slice ":" is supported')
return EdgeSpace(data=EdgeDataView(self._graph, ALL))
else:
return EdgeSpace(data=EdgeDataView(self._graph, edges))
def __call__(self, *args, **kwargs):
"""Return all the edges."""
return self._graph.all_edges(*args, **kwargs)
class EdgeDataView(MutableMapping):
__slot__ = ['_graph', '_edges']
def __init__(self, graph, edges):
self._graph = graph
self._edges = edges
def __getitem__(self, key):
return self._graph.get_e_repr(self._edges)[key]
def __setitem__(self, key, val):
self._graph.set_e_repr({key : val}, self._edges)
def __delitem__(self, key):
if not is_all(self._edges):
raise DGLError('Delete feature data is not supported on only a subset'
' of nodes. Please use `del G.edata[key]` instead.')
self._graph.pop_e_repr(key)
def __len__(self):
return len(self._graph._edge_frame)
def __iter__(self):
return iter(self._graph._edge_frame)
def __repr__(self):
data = self._graph.get_e_repr(self._edges)
return repr({key : data[key] for key in self._graph._edge_frame})
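
These views back the new g.ndata / g.nodes[...].data and g.edata / g.edges[...].data syntax exercised in the tests below; a short usage sketch (shapes and values illustrative):

# Usage sketch for the new views (shapes/values illustrative).
import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
g.ndata['h'] = th.zeros((3, 4))                 # NodeDataView over all nodes
g.nodes[[0, 2]].data['h'] = th.ones((2, 4))     # data view over a node subset
g.edata['w'] = th.randn(2, 4)                   # EdgeDataView over all edges
g.edges[[0], [1]].data['w'] = th.zeros((1, 4))  # view over one edge
del g.ndata['h']                                # same effect as g.pop_n_repr('h')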
@@ -11,20 +11,20 @@ def check_eq(a, b):
     assert a.shape == b.shape
     assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))

-def message_func(src, edge):
-    assert len(src['h'].shape) == 2
-    assert src['h'].shape[1] == D
-    return {'m' : src['h']}
+def message_func(edges):
+    assert len(edges.src['h'].shape) == 2
+    assert edges.src['h'].shape[1] == D
+    return {'m' : edges.src['h']}

-def reduce_func(node, msgs):
-    msgs = msgs['m']
+def reduce_func(nodes):
+    msgs = nodes.mailbox['m']
     reduce_msg_shapes.add(tuple(msgs.shape))
     assert len(msgs.shape) == 3
     assert msgs.shape[2] == D
     return {'m' : mx.nd.sum(msgs, 1)}

-def apply_node_func(node):
-    return {'h' : node['h'] + node['m']}
+def apply_node_func(nodes):
+    return {'h' : nodes.data['h'] + nodes.data['m']}

 def generate_graph(grad=False):
     g = DGLGraph()
@@ -38,7 +38,7 @@ def generate_graph(grad=False):
     ncol = mx.nd.random.normal(shape=(10, D))
     if grad:
         ncol.attach_grad()
-    g.set_n_repr({'h' : ncol})
+    g.ndata['h'] = ncol
     return g

 def test_batch_setter_getter():
@@ -47,15 +47,15 @@ def test_batch_setter_getter():
     g = generate_graph()
     # set all nodes
     g.set_n_repr({'h' : mx.nd.zeros((10, D))})
-    assert _pfc(g.get_n_repr()['h']) == [0.] * 10
+    assert _pfc(g.ndata['h']) == [0.] * 10
     # pop nodes
     assert _pfc(g.pop_n_repr('h')) == [0.] * 10
-    assert len(g.get_n_repr()) == 0
+    assert len(g.ndata) == 0
     g.set_n_repr({'h' : mx.nd.zeros((10, D))})
     # set partial nodes
     u = mx.nd.array([1, 3, 5], dtype='int64')
     g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
-    assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
+    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
     # get partial nodes
     u = mx.nd.array([1, 2, 3], dtype='int64')
     assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
@@ -81,77 +81,77 @@ def test_batch_setter_getter():
     9, 0, 16
     '''
     # set all edges
-    g.set_e_repr({'l' : mx.nd.zeros((17, D))})
-    assert _pfc(g.get_e_repr()['l']) == [0.] * 17
+    g.edata['l'] = mx.nd.zeros((17, D))
+    assert _pfc(g.edata['l']) == [0.] * 17
     # pop edges
     assert _pfc(g.pop_e_repr('l')) == [0.] * 17
-    assert len(g.get_e_repr()) == 0
-    g.set_e_repr({'l' : mx.nd.zeros((17, D))})
+    assert len(g.edata) == 0
+    g.edata['l'] = mx.nd.zeros((17, D))
     # set partial edges (many-many)
     u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
     v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
-    g.set_e_repr({'l' : mx.nd.ones((5, D))}, u, v)
+    g.edges[u, v].data['l'] = mx.nd.ones((5, D))
     truth = [0.] * 17
     truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
-    assert _pfc(g.get_e_repr()['l']) == truth
+    assert _pfc(g.edata['l']) == truth
     # set partial edges (many-one)
     u = mx.nd.array([3, 4, 6], dtype='int64')
     v = mx.nd.array([9], dtype='int64')
-    g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
+    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
     truth[5] = truth[7] = truth[11] = 1.
-    assert _pfc(g.get_e_repr()['l']) == truth
+    assert _pfc(g.edata['l']) == truth
     # set partial edges (one-many)
     u = mx.nd.array([0], dtype='int64')
     v = mx.nd.array([4, 5, 6], dtype='int64')
-    g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
+    g.edges[u, v].data['l'] = mx.nd.ones((3, D))
     truth[6] = truth[8] = truth[10] = 1.
-    assert _pfc(g.get_e_repr()['l']) == truth
+    assert _pfc(g.edata['l']) == truth
     # get partial edges (many-many)
     u = mx.nd.array([0, 6, 0], dtype='int64')
     v = mx.nd.array([6, 9, 7], dtype='int64')
-    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
+    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
     # get partial edges (many-one)
     u = mx.nd.array([5, 6, 7], dtype='int64')
     v = mx.nd.array([9], dtype='int64')
-    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
+    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
     # get partial edges (one-many)
     u = mx.nd.array([0], dtype='int64')
     v = mx.nd.array([3, 4, 5], dtype='int64')
-    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
+    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]

 def test_batch_setter_autograd():
     with mx.autograd.record():
         g = generate_graph(grad=True)
-        h1 = g.get_n_repr()['h']
+        h1 = g.ndata['h']
         h1.attach_grad()
         # partial set
         v = mx.nd.array([1, 2, 8], dtype='int64')
         hh = mx.nd.zeros((len(v), D))
         hh.attach_grad()
         g.set_n_repr({'h' : hh}, v)
-        h2 = g.get_n_repr()['h']
+        h2 = g.ndata['h']
     h2.backward(mx.nd.ones((10, D)) * 2)
     check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
     check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))

 def test_batch_send():
     g = generate_graph()
-    def _fmsg(src, edge):
-        assert src['h'].shape == (5, D)
-        return {'m' : src['h']}
+    def _fmsg(edges):
+        assert edges.src['h'].shape == (5, D)
+        return {'m' : edges.src['h']}
     g.register_message_func(_fmsg)
     # many-many send
     u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
     v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
-    g.send(u, v)
+    g.send((u, v))
     # one-many send
     u = mx.nd.array([0], dtype='int64')
     v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
-    g.send(u, v)
+    g.send((u, v))
     # many-one send
     u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
     v = mx.nd.array([9], dtype='int64')
-    g.send(u, v)
+    g.send((u, v))

 def test_batch_recv():
     # basic recv test
@@ -162,7 +162,7 @@ def test_batch_recv():
     u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
     v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
     reduce_msg_shapes.clear()
-    g.send(u, v)
+    g.send((u, v))
     #g.recv(th.unique(v))
     #assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
     #reduce_msg_shapes.clear()
@@ -177,7 +177,7 @@ def test_update_routines():
     reduce_msg_shapes.clear()
     u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
     v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
-    g.send_and_recv(u, v)
+    g.send_and_recv((u, v))
     assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
     reduce_msg_shapes.clear()
@@ -208,14 +208,14 @@ def test_reduce_0deg():
     g.add_edge(2, 0)
     g.add_edge(3, 0)
     g.add_edge(4, 0)
-    def _message(src, edge):
-        return {'m' : src['h']}
+    def _message(edges):
+        return {'m' : edges.src['h']}
-    def _reduce(node, msgs):
-        return {'h' : node['h'] + msgs['m'].sum(1)}
+    def _reduce(nodes):
+        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
     old_repr = mx.nd.random.normal(shape=(5, 5))
     g.set_n_repr({'h': old_repr})
     g.update_all(_message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
     assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
@@ -224,25 +224,25 @@ def test_pull_0deg():
     g = DGLGraph()
     g.add_nodes(2)
     g.add_edge(0, 1)
-    def _message(src, edge):
-        return {'m' : src['h']}
+    def _message(edges):
+        return {'m' : edges.src['h']}
-    def _reduce(node, msgs):
-        return {'h' : msgs['m'].sum(1)}
+    def _reduce(nodes):
+        return {'h' : nodes.mailbox['m'].sum(1)}
     old_repr = mx.nd.random.normal(shape=(2, 5))
     g.set_n_repr({'h' : old_repr})
     g.pull(0, _message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
     assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
     g.pull(1, _message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
     old_repr = mx.nd.random.normal(shape=(2, 5))
     g.set_n_repr({'h' : old_repr})
     g.pull([0, 1], _message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
     assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
......
@@ -10,20 +10,20 @@ def check_eq(a, b):
     assert a.shape == b.shape
     assert th.sum(a == b) == int(np.prod(list(a.shape)))

-def message_func(src, edge):
-    assert len(src['h'].shape) == 2
-    assert src['h'].shape[1] == D
-    return {'m' : src['h']}
+def message_func(edges):
+    assert len(edges.src['h'].shape) == 2
+    assert edges.src['h'].shape[1] == D
+    return {'m' : edges.src['h']}

-def reduce_func(node, msgs):
-    msgs = msgs['m']
+def reduce_func(nodes):
+    msgs = nodes.mailbox['m']
     reduce_msg_shapes.add(tuple(msgs.shape))
     assert len(msgs.shape) == 3
     assert msgs.shape[2] == D
     return {'accum' : th.sum(msgs, 1)}

-def apply_node_func(node):
-    return {'h' : node['h'] + node['accum']}
+def apply_node_func(nodes):
+    return {'h' : nodes.data['h'] + nodes.data['accum']}

 def generate_graph(grad=False):
     g = DGLGraph()
@@ -36,10 +36,11 @@ def generate_graph(grad=False):
     # add a back flow from 9 to 0
     g.add_edge(9, 0)
     ncol = Variable(th.randn(10, D), requires_grad=grad)
-    accumcol = Variable(th.randn(10, D), requires_grad=grad)
     ecol = Variable(th.randn(17, D), requires_grad=grad)
-    g.set_n_repr({'h' : ncol})
+    g.ndata['h'] = ncol
+    g.edata['w'] = ecol
     g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
+    g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
     return g

 def test_batch_setter_getter():
@@ -47,20 +48,20 @@ def test_batch_setter_getter():
         return list(x.numpy()[:,0])
     g = generate_graph()
     # set all nodes
-    g.set_n_repr({'h' : th.zeros((10, D))})
-    assert _pfc(g.get_n_repr()['h']) == [0.] * 10
+    g.ndata['h'] = th.zeros((10, D))
+    assert th.allclose(g.ndata['h'], th.zeros((10, D)))
     # pop nodes
-    old_len = len(g.get_n_repr())
+    old_len = len(g.ndata)
     assert _pfc(g.pop_n_repr('h')) == [0.] * 10
-    assert len(g.get_n_repr()) == old_len - 1
-    g.set_n_repr({'h' : th.zeros((10, D))})
+    assert len(g.ndata) == old_len - 1
+    g.ndata['h'] = th.zeros((10, D))
     # set partial nodes
     u = th.tensor([1, 3, 5])
-    g.set_n_repr({'h' : th.ones((3, D))}, u)
-    assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
+    g.nodes[u].data['h'] = th.ones((3, D))
+    assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
     # get partial nodes
     u = th.tensor([1, 2, 3])
-    assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
+    assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]

     '''
     s, d, eid
@@ -83,75 +84,75 @@ def test_batch_setter_getter():
     9, 0, 16
     '''
     # set all edges
-    g.set_e_repr({'l' : th.zeros((17, D))})
-    assert _pfc(g.get_e_repr()['l']) == [0.] * 17
+    g.edata['l'] = th.zeros((17, D))
+    assert _pfc(g.edata['l']) == [0.] * 17
     # pop edges
-    old_len = len(g.get_e_repr())
+    old_len = len(g.edata)
     assert _pfc(g.pop_e_repr('l')) == [0.] * 17
-    assert len(g.get_e_repr()) == old_len - 1
-    g.set_e_repr({'l' : th.zeros((17, D))})
+    assert len(g.edata) == old_len - 1
+    g.edata['l'] = th.zeros((17, D))
     # set partial edges (many-many)
     u = th.tensor([0, 0, 2, 5, 9])
     v = th.tensor([1, 3, 9, 9, 0])
-    g.set_e_repr({'l' : th.ones((5, D))}, u, v)
+    g.edges[u, v].data['l'] = th.ones((5, D))
     truth = [0.] * 17
     truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
-    assert _pfc(g.get_e_repr()['l']) == truth
+    assert _pfc(g.edata['l']) == truth
     # set partial edges (many-one)
     u = th.tensor([3, 4, 6])
     v = th.tensor([9])
-    g.set_e_repr({'l' : th.ones((3, D))}, u, v)
+    g.edges[u, v].data['l'] = th.ones((3, D))
     truth[5] = truth[7] = truth[11] = 1.
-    assert _pfc(g.get_e_repr()['l']) == truth
+    assert _pfc(g.edata['l']) == truth
     # set partial edges (one-many)
     u = th.tensor([0])
     v = th.tensor([4, 5, 6])
-    g.set_e_repr({'l' : th.ones((3, D))}, u, v)
+    g.edges[u, v].data['l'] = th.ones((3, D))
     truth[6] = truth[8] = truth[10] = 1.
-    assert _pfc(g.get_e_repr()['l']) == truth
+    assert _pfc(g.edata['l']) == truth
     # get partial edges (many-many)
     u = th.tensor([0, 6, 0])
     v = th.tensor([6, 9, 7])
-    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
+    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
     # get partial edges (many-one)
     u = th.tensor([5, 6, 7])
     v = th.tensor([9])
-    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
+    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
     # get partial edges (one-many)
     u = th.tensor([0])
     v = th.tensor([3, 4, 5])
-    assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
+    assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]

 def test_batch_setter_autograd():
     g = generate_graph(grad=True)
-    h1 = g.get_n_repr()['h']
+    h1 = g.ndata['h']
     # partial set
     v = th.tensor([1, 2, 8])
     hh = Variable(th.zeros((len(v), D)), requires_grad=True)
-    g.set_n_repr({'h' : hh}, v)
-    h2 = g.get_n_repr()['h']
+    g.nodes[v].data['h'] = hh
+    h2 = g.ndata['h']
     h2.backward(th.ones((10, D)) * 2)
     check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
     check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))

 def test_batch_send():
     g = generate_graph()
-    def _fmsg(src, edge):
-        assert src['h'].shape == (5, D)
-        return {'m' : src['h']}
+    def _fmsg(edges):
+        assert edges.src['h'].shape == (5, D)
+        return {'m' : edges.src['h']}
     g.register_message_func(_fmsg)
     # many-many send
     u = th.tensor([0, 0, 0, 0, 0])
     v = th.tensor([1, 2, 3, 4, 5])
-    g.send(u, v)
+    g.send((u, v))
     # one-many send
     u = th.tensor([0])
     v = th.tensor([1, 2, 3, 4, 5])
-    g.send(u, v)
+    g.send((u, v))
     # many-one send
     u = th.tensor([1, 2, 3, 4, 5])
     v = th.tensor([9])
-    g.send(u, v)
+    g.send((u, v))

 def test_batch_recv():
     # basic recv test
@@ -162,11 +163,25 @@ def test_batch_recv():
     u = th.tensor([0, 0, 0, 4, 5, 6])
     v = th.tensor([1, 2, 3, 9, 9, 9])
     reduce_msg_shapes.clear()
-    g.send(u, v)
+    g.send((u, v))
     g.recv(th.unique(v))
     assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
     reduce_msg_shapes.clear()

+def test_update_edges():
+    def _upd(edges):
+        return {'w' : edges.data['w'] * 2}
+    g = generate_graph()
+    g.register_edge_func(_upd)
+    old = g.edata['w']
+    g.update_edges()
+    assert th.allclose(old * 2, g.edata['w'])
+    u = th.tensor([0, 0, 0, 4, 5, 6])
+    v = th.tensor([1, 2, 3, 9, 9, 9])
+    g.update_edges((u, v), lambda edges : {'w' : edges.data['w'] * 0.})
+    eid = g.edge_ids(u, v)
+    assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
+
 def test_update_routines():
     g = generate_graph()
     g.register_message_func(message_func)
@@ -177,7 +192,7 @@ def test_update_routines():
     reduce_msg_shapes.clear()
     u = th.tensor([0, 0, 0, 4, 5, 6])
     v = th.tensor([1, 2, 3, 9, 9, 9])
-    g.send_and_recv(u, v)
+    g.send_and_recv((u, v))
     assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
     reduce_msg_shapes.clear()
@@ -208,14 +223,14 @@ def test_reduce_0deg():
     g.add_edge(2, 0)
     g.add_edge(3, 0)
     g.add_edge(4, 0)
-    def _message(src, edge):
-        return {'m' : src['h']}
+    def _message(edges):
+        return {'m' : edges.src['h']}
-    def _reduce(node, msgs):
-        return {'h' : node['h'] + msgs['m'].sum(1)}
+    def _reduce(nodes):
+        return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
     old_repr = th.randn(5, 5)
-    g.set_n_repr({'h' : old_repr})
+    g.ndata['h'] = old_repr
     g.update_all(_message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert th.allclose(new_repr[1:], old_repr[1:])
     assert th.allclose(new_repr[0], old_repr.sum(0))
@@ -224,26 +239,26 @@ def test_pull_0deg():
     g = DGLGraph()
     g.add_nodes(2)
     g.add_edge(0, 1)
-    def _message(src, edge):
-        return {'m' : src['h']}
+    def _message(edges):
+        return {'m' : edges.src['h']}
-    def _reduce(node, msgs):
-        return {'h' : msgs['m'].sum(1)}
+    def _reduce(nodes):
+        return {'h' : nodes.mailbox['m'].sum(1)}
     old_repr = th.randn(2, 5)
-    g.set_n_repr({'h' : old_repr})
+    g.ndata['h'] = old_repr
     g.pull(0, _message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert th.allclose(new_repr[0], old_repr[0])
     assert th.allclose(new_repr[1], old_repr[1])
     g.pull(1, _message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert th.allclose(new_repr[1], old_repr[0])
     old_repr = th.randn(2, 5)
-    g.set_n_repr({'h' : old_repr})
+    g.ndata['h'] = old_repr
     g.pull([0, 1], _message, _reduce)
-    new_repr = g.get_n_repr()['h']
+    new_repr = g.ndata['h']
     assert th.allclose(new_repr[0], old_repr[0])
     assert th.allclose(new_repr[1], old_repr[0])
@@ -253,27 +268,26 @@ def _disabled_test_send_twice():
     g.add_nodes(3)
     g.add_edge(0, 1)
     g.add_edge(2, 1)
-    def _message_a(src, edge):
-        return {'a': src['a']}
+    def _message_a(edges):
+        return {'a': edges.src['a']}
-    def _message_b(src, edge):
-        return {'a': src['a'] * 3}
+    def _message_b(edges):
+        return {'a': edges.src['a'] * 3}
-    def _reduce(node, msgs):
-        assert msgs is not None
-        return {'a': msgs['a'].max(1)[0]}
+    def _reduce(nodes):
+        return {'a': nodes.mailbox['a'].max(1)[0]}
     old_repr = th.randn(3, 5)
-    g.set_n_repr({'a': old_repr})
-    g.send(0, 1, _message_a)
-    g.send(0, 1, _message_b)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = old_repr
+    g.send((0, 1), _message_a)
+    g.send((0, 1), _message_b)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], old_repr[0] * 3)
-    g.set_n_repr({'a': old_repr})
-    g.send(0, 1, _message_a)
-    g.send(2, 1, _message_b)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = old_repr
+    g.send((0, 1), _message_a)
+    g.send((2, 1), _message_b)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])

 def test_send_multigraph():
@@ -284,64 +298,63 @@ def test_send_multigraph():
     g.add_edge(0, 1)
     g.add_edge(2, 1)
-    def _message_a(src, edge):
-        return {'a': edge['a']}
+    def _message_a(edges):
+        return {'a': edges.data['a']}
-    def _message_b(src, edge):
-        return {'a': edge['a'] * 3}
+    def _message_b(edges):
+        return {'a': edges.data['a'] * 3}
-    def _reduce(node, msgs):
-        assert msgs is not None
-        return {'a': msgs['a'].max(1)[0]}
+    def _reduce(nodes):
+        return {'a': nodes.mailbox['a'].max(1)[0]}
     def answer(*args):
         return th.stack(args, 0).max(0)[0]
     # send by eid
     old_repr = th.randn(4, 5)
-    g.set_n_repr({'a': th.zeros(3, 5)})
-    g.set_e_repr({'a': old_repr})
-    g.send(eid=[0, 2], message_func=_message_a)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = th.zeros(3, 5)
+    g.edata['a'] = old_repr
+    g.send([0, 2], message_func=_message_a)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
-    g.set_n_repr({'a': th.zeros(3, 5)})
-    g.set_e_repr({'a': old_repr})
-    g.send(eid=[0, 2, 3], message_func=_message_a)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = th.zeros(3, 5)
+    g.edata['a'] = old_repr
+    g.send([0, 2, 3], message_func=_message_a)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
     # send on multigraph
-    g.set_n_repr({'a': th.zeros(3, 5)})
-    g.set_e_repr({'a': old_repr})
-    g.send([0, 2], [1, 1], _message_a)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = th.zeros(3, 5)
+    g.edata['a'] = old_repr
+    g.send(([0, 2], [1, 1]), _message_a)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], old_repr.max(0)[0])
     # consecutive send and send_on
-    g.set_n_repr({'a': th.zeros(3, 5)})
-    g.set_e_repr({'a': old_repr})
-    g.send(2, 1, _message_a)
-    g.send(eid=[0, 1], message_func=_message_b)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = th.zeros(3, 5)
+    g.edata['a'] = old_repr
+    g.send((2, 1), _message_a)
+    g.send([0, 1], message_func=_message_b)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
     # consecutive send_on
-    g.set_n_repr({'a': th.zeros(3, 5)})
-    g.set_e_repr({'a': old_repr})
-    g.send(eid=0, message_func=_message_a)
-    g.send(eid=1, message_func=_message_b)
-    g.recv([1], _reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = th.zeros(3, 5)
+    g.edata['a'] = old_repr
+    g.send(0, message_func=_message_a)
+    g.send(1, message_func=_message_b)
+    g.recv(1, _reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
     # send_and_recv_on
-    g.set_n_repr({'a': th.zeros(3, 5)})
-    g.set_e_repr({'a': old_repr})
-    g.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce)
-    new_repr = g.get_n_repr()['a']
+    g.ndata['a'] = th.zeros(3, 5)
+    g.edata['a'] = old_repr
+    g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
+    new_repr = g.ndata['a']
     assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
     assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
@@ -353,29 +366,25 @@ def test_dynamic_addition():
     # Test node addition
     g.add_nodes(N)
-    g.set_n_repr({'h1': th.randn(N, D),
-                  'h2': th.randn(N, D)})
+    g.ndata.update({'h1': th.randn(N, D),
+                    'h2': th.randn(N, D)})
     g.add_nodes(3)
-    n_repr = g.get_n_repr()
-    assert n_repr['h1'].shape[0] == n_repr['h2'].shape[0] == N + 3
+    assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3
     # Test edge addition
     g.add_edge(0, 1)
     g.add_edge(1, 0)
-    g.set_e_repr({'h1': th.randn(2, D),
-                  'h2': th.randn(2, D)})
+    g.edata.update({'h1': th.randn(2, D),
+                    'h2': th.randn(2, D)})
-    e_repr = g.get_e_repr()
-    assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 2
+    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2
     g.add_edges([0, 2], [2, 0])
-    e_repr = g.get_e_repr()
-    g.set_e_repr({'h1': th.randn(4, D)})
-    assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 4
+    g.edata['h1'] = th.randn(4, D)
+    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4
     g.add_edge(1, 2)
-    g.set_e_repr_by_id({'h1': th.randn(1, D)}, eid=4)
-    e_repr = g.get_e_repr()
-    assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 5
+    g.edges[4].data['h1'] = th.randn(1, D)
+    assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5

 if __name__ == '__main__':
@@ -383,6 +392,7 @@ if __name__ == '__main__':
     test_batch_setter_autograd()
     test_batch_send()
     test_batch_recv()
+    test_update_edges()
     test_update_routines()
     test_reduce_0deg()
     test_pull_0deg()
......
@@ -18,8 +18,8 @@ def tree1():
     g.add_edge(4, 1)
     g.add_edge(1, 0)
     g.add_edge(2, 0)
-    g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
-    g.set_e_repr({'h' : th.randn(4, 10)})
+    g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
+    g.edata['h'] = th.randn(4, 10)
     return g

 def tree2():
@@ -37,17 +37,17 @@ def tree2():
     g.add_edge(0, 4)
     g.add_edge(4, 1)
     g.add_edge(3, 1)
-    g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
-    g.set_e_repr({'h' : th.randn(4, 10)})
+    g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
+    g.edata['h'] = th.randn(4, 10)
     return g

 def test_batch_unbatch():
     t1 = tree1()
     t2 = tree2()
-    n1 = t1.get_n_repr()['h']
-    n2 = t2.get_n_repr()['h']
-    e1 = t1.get_e_repr()['h']
-    e2 = t2.get_e_repr()['h']
+    n1 = t1.ndata['h']
+    n2 = t2.ndata['h']
+    e1 = t1.edata['h']
+    e2 = t2.edata['h']

     bg = dgl.batch([t1, t2])
     assert bg.number_of_nodes() == 10
@@ -57,10 +57,10 @@ def test_batch_unbatch():
     assert bg.batch_num_edges == [4, 4]

     tt1, tt2 = dgl.unbatch(bg)
-    assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
-    assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
-    assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
-    assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])
+    assert th.allclose(t1.ndata['h'], tt1.ndata['h'])
+    assert th.allclose(t1.edata['h'], tt1.edata['h'])
+    assert th.allclose(t2.ndata['h'], tt2.ndata['h'])
+    assert th.allclose(t2.edata['h'], tt2.edata['h'])

 def test_batch_unbatch1():
     t1 = tree1()
@@ -74,29 +74,29 @@ def test_batch_unbatch1():
     assert b2.batch_num_edges == [4, 4, 4]

     s1, s2, s3 = dgl.unbatch(b2)
-    assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
-    assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
-    assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
-    assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
-    assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
-    assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])
+    assert th.allclose(t2.ndata['h'], s1.ndata['h'])
+    assert th.allclose(t2.edata['h'], s1.edata['h'])
+    assert th.allclose(t1.ndata['h'], s2.ndata['h'])
+    assert th.allclose(t1.edata['h'], s2.edata['h'])
+    assert th.allclose(t2.ndata['h'], s3.ndata['h'])
+    assert th.allclose(t2.edata['h'], s3.edata['h'])

 def test_batch_sendrecv():
     t1 = tree1()
     t2 = tree2()

     bg = dgl.batch([t1, t2])
-    bg.register_message_func(lambda src, edge: {'m' : src['h']})
-    bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
+    bg.register_message_func(lambda edges: {'m' : edges.src['h']})
+    bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
     u = [3, 4, 2 + 5, 0 + 5]
     v = [1, 1, 4 + 5, 4 + 5]

-    bg.send(u, v)
+    bg.send((u, v))
     bg.recv(v)

     t1, t2 = dgl.unbatch(bg)
-    assert t1.get_n_repr()['h'][1] == 7
-    assert t2.get_n_repr()['h'][4] == 2
+    assert t1.ndata['h'][1] == 7
+    assert t2.ndata['h'][4] == 2

 def test_batch_propagate():
@@ -104,8 +104,8 @@ def test_batch_propagate():
     t2 = tree2()

     bg = dgl.batch([t1, t2])
-    bg.register_message_func(lambda src, edge: {'m' : src['h']})
-    bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
+    bg.register_message_func(lambda edges: {'m' : edges.src['h']})
+    bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
     # get leaves.

     order = []
@@ -123,23 +123,23 @@ def test_batch_propagate():
     bg.propagate(traverser=order)

     t1, t2 = dgl.unbatch(bg)
-    assert t1.get_n_repr()['h'][0] == 9
-    assert t2.get_n_repr()['h'][1] == 5
+    assert t1.ndata['h'][0] == 9
+    assert t2.ndata['h'][1] == 5

 def test_batched_edge_ordering():
     g1 = dgl.DGLGraph()
     g1.add_nodes(6)
     g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
     e1 = th.randn(5, 10)
-    g1.set_e_repr({'h' : e1})
+    g1.edata['h'] = e1

     g2 = dgl.DGLGraph()
     g2.add_nodes(6)
     g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
     e2 = th.randn(6, 10)
-    g2.set_e_repr({'h' : e2})
+    g2.edata['h'] = e2

     g = dgl.batch([g1, g2])
-    r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
-    r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
+    r1 = g.edata['h'][g.edge_id(4, 5)]
+    r2 = g1.edata['h'][g1.edge_id(4, 5)]
     assert th.equal(r1, r2)

 def test_batch_no_edge():
......
@@ -12,8 +12,8 @@ def test_filter():
     n_repr[[1, 3]] = 1
     e_repr[[1, 3]] = 1

-    g.set_n_repr({'a': n_repr})
-    g.set_e_repr({'a': e_repr})
+    g.ndata['a'] = n_repr
+    g.edata['a'] = e_repr

     def predicate(r):
         return r['a'].max(1)[0] > 0
......
@@ -6,7 +6,7 @@ def generate_graph():
     g = dgl.DGLGraph()
     g.add_nodes(10) # 10 nodes.
     h = th.arange(1, 11, dtype=th.float)
-    g.set_n_repr({'h': h})
+    g.ndata['h'] = h
     # create a graph where 0 is the source and 9 is the sink
     for i in range(1, 9):
         g.add_edge(0, i)
@@ -15,29 +15,11 @@ def generate_graph():
     g.add_edge(9, 0)
     h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
                    1., 7., 1., 8., 1., 9., 10.])
-    g.set_e_repr({'h' : h})
+    g.edata['h'] = h
     return g

-def generate_graph1():
-    """graph with anonymous repr"""
-    g = dgl.DGLGraph()
-    g.add_nodes(10) # 10 nodes.
-    h = th.arange(1, 11, dtype=th.float)
-    h = th.arange(1, 11, dtype=th.float)
-    g.set_n_repr(h)
-    # create a graph where 0 is the source and 9 is the sink
-    for i in range(1, 9):
-        g.add_edge(0, i)
-        g.add_edge(i, 9)
-    # add a back flow from 9 to 0
-    g.add_edge(9, 0)
-    h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
-                   1., 7., 1., 8., 1., 9., 10.])
-    g.set_e_repr(h)
-    return g
-
-def reducer_both(node, msgs):
-    return {'h' : th.sum(msgs['m'], 1)}
+def reducer_both(nodes):
+    return {'h' : th.sum(nodes.mailbox['m'], 1)}

 def test_copy_src():
     # copy_src with both fields
@@ -45,7 +27,7 @@ def test_copy_src():
     g.register_message_func(fn.copy_src(src='h', out='m'))
     g.register_reduce_func(reducer_both)
     g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
+    assert th.allclose(g.ndata['h'],
                        th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

 def test_copy_edge():
@@ -54,7 +36,7 @@ def test_copy_edge():
     g.register_message_func(fn.copy_edge(edge='h', out='m'))
     g.register_reduce_func(reducer_both)
     g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
+    assert th.allclose(g.ndata['h'],
                        th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))

 def test_src_mul_edge():
@@ -63,7 +45,7 @@ def test_src_mul_edge():
     g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
     g.register_reduce_func(reducer_both)
     g.update_all()
-    assert th.allclose(g.get_n_repr()['h'],
+    assert th.allclose(g.ndata['h'],
                        th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))

 if __name__ == '__main__':
......