"docs/vscode:/vscode.git/clone" did not exist on "ed66a209708ad4c0f442ddb21d0678c013b35f89"
Unverified Commit 68ec6247 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[API][Doc] API change & basic tutorials (#113)

* Add SH tutorials

* setup sphinx-gallery; work on graph tutorial

* draft dglgraph tutorial

* update readme to include document url

* rm obsolete file

* Draft the message passing tutorial

* Capsule code (#102)

* add capsule example

* clean code

* better naming

* better naming

* [GCN]tutorial scaffold

* fix capsule example code

* remove previous capsule example code

* graph struc edit

* modified:   2_graph.py

* update doc of capsule

* update capsule docs

* update capsule docs

* add msg passing primer

* GCN-GAT tutorial Section 1 and 2

* comment for API improvement

* section 3

* Tutorial API change (#115)

* change the API as discussed; toy example

* enable the new set/get syntax

* fixed pytorch utest

* fixed gcn example

* fixed gat example

* fixed mx utests

* fix mx utest

* delete apply edges; add utest for update_edges

* small change on toy example

* fix utest

* fix out in degrees bug

* update pagerank example and add it to CI

* add delitem for dataview

* make edges() return form that is compatible with send/update_edges etc

* fix index bug when the given data is one-int-tensor

* fix doc
parent 2ecd2b23
...@@ -17,7 +17,7 @@ def build_dgl() { ...@@ -17,7 +17,7 @@ def build_dgl() {
} }
dir ('build') { dir ('build') {
sh 'cmake ..' sh 'cmake ..'
sh 'make -j$(nproc)' sh 'make -j4'
} }
} }
......
...@@ -46,6 +46,7 @@ extensions = [ ...@@ -46,6 +46,7 @@ extensions = [
'sphinx.ext.napoleon', 'sphinx.ext.napoleon',
'sphinx.ext.viewcode', 'sphinx.ext.viewcode',
'sphinx.ext.intersphinx', 'sphinx.ext.intersphinx',
'sphinx.ext.graphviz',
'sphinx_gallery.gen_gallery', 'sphinx_gallery.gen_gallery',
] ]
......
from __future__ import division
import networkx as nx
from dgl.graph import DGLGraph
# PageRank hyper-parameters: damping factor, node count, and the number
# of power-iteration rounds.
DAMP = 0.85
N = 100
K = 10


def message_func(src, dst, edge):
    """Message UDF: a node sends its PageRank mass divided by its out-degree."""
    pv, deg = src['pv'], src['deg']
    return pv / deg


def update_func(node, accum):
    """Update UDF: mix the damped accumulated mass with the random-jump term."""
    rank = (1 - DAMP) / N + DAMP * accum
    return {'pv': rank}
def compute_pagerank(g):
    """Run K rounds of PageRank over graph ``g`` via DGL message passing.

    Parameters
    ----------
    g : networkx graph (or anything DGLGraph accepts)
        The input graph; it is wrapped in a ``DGLGraph``.

    Returns
    -------
    list
        The per-node 'pv' values, in node order.
    """
    dg = DGLGraph(g)
    print(dg.number_of_edges(), dg.number_of_nodes())
    # Register the message, reduce ('sum'), and node-update UDFs once so
    # update_all() can be called without arguments.
    dg.register_message_func(message_func)
    dg.register_update_func(update_func)
    dg.register_reduce_func('sum')
    # Start from the uniform distribution and cache each node's out-degree.
    for n in dg.nodes():
        attrs = dg.node[n]
        attrs['pv'] = 1 / N
        attrs['deg'] = dg.out_degree(n)
    # Power iteration: one full message-passing sweep per round.
    for _ in range(K):
        dg.update_all()
    return [dg.node[n]['pv'] for n in dg.nodes()]
if __name__ == '__main__':
    # Demo: PageRank on a random Erdos-Renyi graph with N nodes.
    graph = nx.erdos_renyi_graph(N, 0.05)
    print(compute_pagerank(graph))
...@@ -16,18 +16,18 @@ import dgl ...@@ -16,18 +16,18 @@ import dgl
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gat_message(src, edge): def gat_message(edges):
return {'ft' : src['ft'], 'a2' : src['a2']} return {'ft' : edges.src['ft'], 'a2' : edges.src['a2']}
class GATReduce(nn.Module): class GATReduce(nn.Module):
def __init__(self, attn_drop): def __init__(self, attn_drop):
super(GATReduce, self).__init__() super(GATReduce, self).__init__()
self.attn_drop = attn_drop self.attn_drop = attn_drop
def forward(self, node, msgs): def forward(self, nodes):
a1 = torch.unsqueeze(node['a1'], 1) # shape (B, 1, 1) a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1) a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D) ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention # attention
a = a1 + a2 # shape (B, deg, 1) a = a1 + a2 # shape (B, deg, 1)
e = F.softmax(F.leaky_relu(a), dim=1) e = F.softmax(F.leaky_relu(a), dim=1)
...@@ -46,13 +46,13 @@ class GATFinalize(nn.Module): ...@@ -46,13 +46,13 @@ class GATFinalize(nn.Module):
if indim != hiddendim: if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim) self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node): def forward(self, nodes):
ret = node['accum'] ret = nodes.data['accum']
if self.residual: if self.residual:
if self.residual_fc is not None: if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret ret = self.residual_fc(nodes.data['h']) + ret
else: else:
ret = node['h'] + ret ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)} return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(nn.Module): class GATPrepare(nn.Module):
...@@ -120,7 +120,7 @@ class GAT(nn.Module): ...@@ -120,7 +120,7 @@ class GAT(nn.Module):
for hid in range(self.num_heads): for hid in range(self.num_heads):
i = l * self.num_heads + hid i = l * self.num_heads + hid
# prepare # prepare
self.g.set_n_repr(self.prp[i](last)) self.g.ndata.update(self.prp[i](last))
# message passing # message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i]) self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads # merge all the heads
...@@ -128,7 +128,7 @@ class GAT(nn.Module): ...@@ -128,7 +128,7 @@ class GAT(nn.Module):
[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)], [self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1) dim=1)
# output projection # output projection
self.g.set_n_repr(self.prp[-1](last)) self.g.ndata.update(self.prp[-1](last))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1]) self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0') return self.g.pop_n_repr('head0')
......
...@@ -14,24 +14,22 @@ import torch.nn.functional as F ...@@ -14,24 +14,22 @@ import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
def gcn_msg(src, edge): def gcn_msg(edges):
return {'m' : src['h']} return {'m' : edges.src['h']}
def gcn_reduce(node, msgs): def gcn_reduce(nodes):
return {'h' : torch.sum(msgs['m'], 1)} return {'h' : torch.sum(nodes.mailbox['m'], 1)}
class NodeApplyModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__() super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats) self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, nodes):
h = self.linear(node['h']) h = self.linear(nodes.data['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return {'h' : h} return {'h' : h}
class GCN(nn.Module): class GCN(nn.Module):
...@@ -62,13 +60,13 @@ class GCN(nn.Module): ...@@ -62,13 +60,13 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr({'h' : features}) self.g.ndata['h'] = features
for layer in self.layers: for layer in self.layers:
# apply dropout # apply dropout
if self.dropout: if self.dropout:
self.g.apply_nodes(apply_node_func= self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])}) lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer) self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h') return self.g.pop_n_repr('h')
......
...@@ -22,8 +22,8 @@ class NodeApplyModule(nn.Module): ...@@ -22,8 +22,8 @@ class NodeApplyModule(nn.Module):
self.linear = nn.Linear(in_feats, out_feats) self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, nodes):
h = self.linear(node['h']) h = self.linear(nodes.data['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
...@@ -57,13 +57,13 @@ class GCN(nn.Module): ...@@ -57,13 +57,13 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr({'h' : features}) self.g.ndata['h'] = features
for layer in self.layers: for layer in self.layers:
# apply dropout # apply dropout
if self.dropout: if self.dropout:
self.g.apply_nodes(apply_node_func= self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])}) lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(fn.copy_src(src='h', out='m'), self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'), fn.sum(msg='m', out='h'),
layer) layer)
......
import networkx as nx
import torch
import dgl
import dgl.function as fn
# Problem size and PageRank hyper-parameters.
N = 100       # number of nodes
DAMP = 0.85   # damping factor
K = 10        # number of power-iteration rounds

# Build a random Erdos-Renyi graph and wrap it as a DGLGraph.
# Fix: the original read ``nx.nx.erdos_renyi_graph`` — the generator lives
# directly on the top-level ``networkx`` namespace, not on a ``nx`` submodule.
g = nx.erdos_renyi_graph(N, 0.05)
g = dgl.DGLGraph(g)
def compute_pagerank(g):
    """Run K rounds of PageRank on DGLGraph ``g`` using DGL builtin functions.

    Returns
    -------
    torch.Tensor
        The length-N tensor of node PageRank values ('pv').
    """
    # Uniform initial distribution over the N nodes.
    g.ndata['pv'] = torch.ones(N) / N
    out_deg = g.out_degrees(g.nodes()).type(torch.float32)
    for _ in range(K):
        # Split each node's mass across its outgoing edges ...
        g.ndata['pv'] = g.ndata['pv'] / out_deg
        # ... send it along every edge and sum at the destinations ...
        g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                     reduce_func=fn.sum(msg='m', out='pv'))
        # ... then apply damping with the uniform random-jump term.
        g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
    return g.ndata['pv']
# Run the demo and print the resulting PageRank vector.
pv = compute_pagerank(g)
print(pv)
...@@ -11,3 +11,4 @@ from .base import ALL ...@@ -11,3 +11,4 @@ from .base import ALL
from .batched_graph import * from .batched_graph import *
from .graph import DGLGraph from .graph import DGLGraph
from .subgraph import DGLSubGraph from .subgraph import DGLSubGraph
from .udf import NodeBatch, EdgeBatch
...@@ -10,7 +10,7 @@ __all__ = ["src_mul_edge", "copy_src", "copy_edge"] ...@@ -10,7 +10,7 @@ __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object): class MessageFunction(object):
"""Base builtin message function class.""" """Base builtin message function class."""
def __call__(self, src, edge): def __call__(self, edges):
"""Regular computation of this builtin. """Regular computation of this builtin.
This will be used when optimization is not available. This will be used when optimization is not available.
...@@ -38,14 +38,10 @@ class BundledMessageFunction(MessageFunction): ...@@ -38,14 +38,10 @@ class BundledMessageFunction(MessageFunction):
return False return False
return True return True
def __call__(self, src, edge): def __call__(self, edges):
ret = None ret = dict()
for fn in self.fn_list: for fn in self.fn_list:
msg = fn(src, edge) msg = fn(edges)
if ret is None:
ret = msg
else:
# ret and msg must be dict
ret.update(msg) ret.update(msg)
return ret return ret
...@@ -83,8 +79,9 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -83,8 +79,9 @@ class SrcMulEdgeMessageFunction(MessageFunction):
return _is_spmv_supported_node_feat(g, self.src_field) \ return _is_spmv_supported_node_feat(g, self.src_field) \
and _is_spmv_supported_edge_feat(g, self.edge_field) and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, edges):
ret = self.mul_op(src[self.src_field], edge[self.edge_field]) ret = self.mul_op(edges.src[self.src_field],
edges.data[self.edge_field])
return {self.out_field : ret} return {self.out_field : ret}
def name(self): def name(self):
...@@ -98,8 +95,8 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -98,8 +95,8 @@ class CopySrcMessageFunction(MessageFunction):
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field) return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge): def __call__(self, edges):
return {self.out_field : src[self.src_field]} return {self.out_field : edges.src[self.src_field]}
def name(self): def name(self):
return "copy_src" return "copy_src"
...@@ -114,15 +111,8 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -114,15 +111,8 @@ class CopyEdgeMessageFunction(MessageFunction):
return False return False
# return _is_spmv_supported_edge_feat(g, self.edge_field) # return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, edges):
if self.edge_field is not None: return {self.out_field : edges.data[self.edge_field]}
ret = edge[self.edge_field]
else:
ret = edge
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return "copy_edge" return "copy_edge"
......
...@@ -8,7 +8,7 @@ __all__ = ["sum", "max"] ...@@ -8,7 +8,7 @@ __all__ = ["sum", "max"]
class ReduceFunction(object): class ReduceFunction(object):
"""Base builtin reduce function class.""" """Base builtin reduce function class."""
def __call__(self, node, msgs): def __call__(self, nodes):
"""Regular computation of this builtin. """Regular computation of this builtin.
This will be used when optimization is not available. This will be used when optimization is not available.
...@@ -35,14 +35,10 @@ class BundledReduceFunction(ReduceFunction): ...@@ -35,14 +35,10 @@ class BundledReduceFunction(ReduceFunction):
return False return False
return True return True
def __call__(self, node, msgs): def __call__(self, nodes):
ret = None ret = dict()
for fn in self.fn_list: for fn in self.fn_list:
rpr = fn(node, msgs) rpr = fn(nodes)
if ret is None:
ret = rpr
else:
# ret and rpr must be dict
ret.update(rpr) ret.update(rpr)
return ret return ret
...@@ -60,8 +56,8 @@ class ReducerFunctionTemplate(ReduceFunction): ...@@ -60,8 +56,8 @@ class ReducerFunctionTemplate(ReduceFunction):
# NOTE: only sum is supported right now. # NOTE: only sum is supported right now.
return self.name == "sum" return self.name == "sum"
def __call__(self, node, msgs): def __call__(self, nodes):
return {self.out_field : self.op(msgs[self.msg_field], 1)} return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}
def name(self): def name(self):
return self.name return self.name
......
...@@ -9,14 +9,16 @@ import dgl ...@@ -9,14 +9,16 @@ import dgl
from .base import ALL, is_all, DGLError, dgl_warning from .base import ALL, is_all, DGLError, dgl_warning
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .frame import FrameRef, merge_frames from .frame import FrameRef, Frame, merge_frames
from .function.message import BundledMessageFunction from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex, create_graph_index from .graph_index import GraphIndex, create_graph_index
from . import scheduler from . import scheduler
from .udf import NodeBatch, EdgeBatch
from . import utils from . import utils
from .view import NodeView, EdgeView
__all__ = ['DLGraph'] __all__ = ['DGLGraph']
class DGLGraph(object): class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
...@@ -58,7 +60,6 @@ class DGLGraph(object): ...@@ -58,7 +60,6 @@ class DGLGraph(object):
self._reduce_func = None self._reduce_func = None
self._edge_func = None self._edge_func = None
self._apply_node_func = None self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None): def add_nodes(self, num, reprs=None):
"""Add nodes. """Add nodes.
...@@ -340,67 +341,94 @@ class DGLGraph(object): ...@@ -340,67 +341,94 @@ class DGLGraph(object):
src, dst, _ = self._graph.find_edges(eid) src, dst, _ = self._graph.find_edges(eid)
return src.tousertensor(), dst.tousertensor() return src.tousertensor(), dst.tousertensor()
def in_edges(self, v): def in_edges(self, v, form='uv'):
"""Return the in edges of the node(s). """Return the in edges of the node(s).
Parameters Parameters
---------- ----------
v : int, list, tensor v : int, list, tensor
The node(s). The node(s).
form : str, optional
The return form. Currently support:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Returns Returns
------- -------
tensor A tuple of Tensors (u, v, eid) if form == 'all'
The src nodes. A pair of Tensors (u, v) if form == 'uv'
tensor One Tensor if form == 'eid'
The dst nodes.
tensor
The edge ids.
""" """
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v) src, dst, eid = self._graph.in_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor() if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def out_edges(self, v): def out_edges(self, v, form='uv'):
"""Return the out edges of the node(s). """Return the out edges of the node(s).
Parameters Parameters
---------- ----------
v : int, list, tensor v : int, list, tensor
The node(s). The node(s).
form : str, optional
The return form. Currently support:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Returns Returns
------- -------
tensor A tuple of Tensors (u, v, eid) if form == 'all'
The src nodes. A pair of Tensors (u, v) if form == 'uv'
tensor One Tensor if form == 'eid'
The dst nodes.
tensor
The edge ids.
""" """
v = utils.toindex(v) v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v) src, dst, eid = self._graph.out_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor() if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def edges(self, sorted=False): def all_edges(self, form='uv', sorted=False):
"""Return all the edges. """Return all the edges.
Parameters Parameters
---------- ----------
form : str, optional
The return form. Currently support:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
sorted : bool sorted : bool
True if the returned edges are sorted by their src and dst ids. True if the returned edges are sorted by their src and dst ids.
Returns Returns
------- -------
tensor A tuple of Tensors (u, v, eid) if form == 'all'
The src nodes. A pair of Tensors (u, v) if form == 'uv'
tensor One Tensor if form == 'eid'
The dst nodes.
tensor
The edge ids.
""" """
src, dst, eid = self._graph.edges(sorted) src, dst, eid = self._graph.edges(sorted)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor() if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def in_degree(self, v): def in_degree(self, v):
"""Return the in degree of the node. """Return the in degree of the node.
...@@ -430,6 +458,7 @@ class DGLGraph(object): ...@@ -430,6 +458,7 @@ class DGLGraph(object):
tensor tensor
The in degree array. The in degree array.
""" """
v = utils.toindex(v)
return self._graph.in_degrees(v).tousertensor() return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v): def out_degree(self, v):
...@@ -460,6 +489,7 @@ class DGLGraph(object): ...@@ -460,6 +489,7 @@ class DGLGraph(object):
tensor tensor
The out degree array. The out degree array.
""" """
v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor() return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None): def to_networkx(self, node_attrs=None, edge_attrs=None):
...@@ -581,6 +611,26 @@ class DGLGraph(object): ...@@ -581,6 +611,26 @@ class DGLGraph(object):
""" """
self._edge_frame.set_initializer(initializer) self._edge_frame.set_initializer(initializer)
@property
def nodes(self):
"""Return a node view that can used to set/get feature data."""
return NodeView(self)
@property
def ndata(self):
"""Return the data view of all the nodes."""
return self.nodes[:].data
@property
def edges(self):
"""Return a edges view that can used to set/get feature data."""
return EdgeView(self)
@property
def edata(self):
"""Return the data view of all the edges."""
return self.edges[:].data
def set_n_repr(self, hu, u=ALL, inplace=False): def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
...@@ -660,7 +710,7 @@ class DGLGraph(object): ...@@ -660,7 +710,7 @@ class DGLGraph(object):
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, he, u=ALL, v=ALL, inplace=False): def set_e_repr(self, he, edges=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
`he` is a dictionary from the feature name to feature tensor. Each tensor `he` is a dictionary from the feature name to feature tensor. Each tensor
...@@ -674,51 +724,29 @@ class DGLGraph(object): ...@@ -674,51 +724,29 @@ class DGLGraph(object):
---------- ----------
he : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
u : node, container or tensor edges : edges
The source node(s). Edges can be a pair of endpoint nodes (u, v), or a
v : node, container or tensor tensor of edge ids. The default value is all the edges.
The destination node(s).
inplace : bool inplace : bool
True if the update is done inplacely True if the update is done inplacely
""" """
# sanity check # parse argument
if not utils.is_dict_like(he): if is_all(edges):
raise DGLError('Expect dictionary type for feature data.' eid = ALL
' Got "%s" instead.' % type(he)) elif isinstance(edges, tuple):
u_is_all = is_all(u) u, v = edges
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
self.set_e_repr_by_id(he, eid=ALL, inplace=inplace)
else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(he, eid=eid, inplace=inplace) else:
eid = utils.toindex(edges)
def set_e_repr_by_id(self, he, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
Parameters
----------
he : tensor or dict of tensor
Edge representation.
eid : int, container or tensor
The edge id(s).
inplace : bool
True if the update is done inplacely
"""
# sanity check # sanity check
if not utils.is_dict_like(he): if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.' raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he)) ' Got "%s" instead.' % type(he))
if is_all(eid): if is_all(eid):
num_edges = self.number_of_edges() num_edges = self.number_of_edges()
else: else:
...@@ -738,33 +766,39 @@ class DGLGraph(object): ...@@ -738,33 +766,39 @@ class DGLGraph(object):
# update row # update row
self._edge_frame.update_rows(eid, he, inplace=inplace) self._edge_frame.update_rows(eid, he, inplace=inplace)
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, edges=ALL):
"""Get node(s) representation. """Get node(s) representation.
Parameters Parameters
---------- ----------
u : node, container or tensor edges : edges
The source node(s). Edges can be a pair of endpoint nodes (u, v), or a
v : node, container or tensor tensor of edge ids. The default value is all the edges.
The destination node(s).
Returns Returns
------- -------
dict dict
Representation dict Representation dict
""" """
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0: if len(self.edge_attr_schemes()) == 0:
return dict() return dict()
if u_is_all: # parse argument
return self.get_e_repr_by_id(eid=ALL) if is_all(edges):
else: eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
return self.get_e_repr_by_id(eid=eid) else:
eid = utils.toindex(edges)
if is_all(eid):
return dict(self._edge_frame)
else:
eid = utils.toindex(eid)
return self._edge_frame.select_rows(eid)
def pop_e_repr(self, key): def pop_e_repr(self, key):
"""Get and remove the specified edge repr. """Get and remove the specified edge repr.
...@@ -781,27 +815,6 @@ class DGLGraph(object): ...@@ -781,27 +815,6 @@ class DGLGraph(object):
""" """
return self._edge_frame.pop(key) return self._edge_frame.pop(key)
def get_e_repr_by_id(self, eid=ALL):
"""Get edge(s) representation by edge id.
Parameters
----------
eid : int, container or tensor
The edge id(s).
Returns
-------
dict
Representation dict from feature name to feature tensor.
"""
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid):
return dict(self._edge_frame)
else:
eid = utils.toindex(eid)
return self._edge_frame.select_rows(eid)
def register_edge_func(self, edge_func): def register_edge_func(self, edge_func):
"""Register global edge update function. """Register global edge update function.
...@@ -842,16 +855,6 @@ class DGLGraph(object): ...@@ -842,16 +855,6 @@ class DGLGraph(object):
""" """
self._apply_node_func = apply_node_func self._apply_node_func = apply_node_func
def register_apply_edge_func(self, apply_edge_func):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
"""
self._apply_edge_func = apply_edge_func
def apply_nodes(self, v=ALL, apply_node_func="default"): def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
...@@ -887,69 +890,24 @@ class DGLGraph(object): ...@@ -887,69 +890,24 @@ class DGLGraph(object):
if reduce_accum is not None: if reduce_accum is not None:
# merge current node_repr with reduce output # merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr) curr_repr = utils.HybridDict(reduce_accum, curr_repr)
new_repr = apply_node_func(curr_repr) nb = NodeBatch(self, v, curr_repr)
new_repr = apply_node_func(nb)
if reduce_accum is not None: if reduce_accum is not None:
# merge new node_repr with reduce output # merge new node_repr with reduce output
reduce_accum.update(new_repr) reduce_accum.update(new_repr)
new_repr = reduce_accum new_repr = reduce_accum
self.set_n_repr(new_repr, v) self.set_n_repr(new_repr, v)
def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None): def send(self, edges=ALL, message_func="default"):
"""Apply the function on edge representations. """Send messages along the given edges.
Applying a None function will be ignored.
Parameters
----------
u : optional, int, iterable of int, tensor
The src node id(s).
v : optional, int, iterable of int, tensor
The dst node id(s).
apply_edge_func : callable
The apply edge function.
eid : None, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
"""
if apply_edge_func == "default":
apply_edge_func = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if eid is None:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
new_repr = apply_edge_func(self.get_e_repr_by_id(eid))
self.set_e_repr_by_id(new_repr, eid)
def send(self, u=None, v=None, message_func="default", eid=None):
"""Trigger the message function on edge u->v or eid
The message function should be compatible with following signature:
(node_reprs, edge_reprs) -> message
It computes the representation of a message using the
representations of the source node, and the edge u->v.
All node_reprs and edge_reprs are dictionaries.
The message function can be any of the pre-defined functions
('from_src').
Currently, we require the message functions of consecutive send's to
return the same keys. Otherwise the behavior will be undefined.
TODO(minjie): document on multiple send behavior
Parameters Parameters
---------- ----------
u : optional, node, container or tensor edges : edges, optional
The source node(s). Edges can be a pair of endpoint nodes (u, v), or a
v : optional, node, container or tensor tensor of edge ids. The default value is all the edges.
The destination node(s).
message_func : callable message_func : callable
The message function. The message function.
eid : optional, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
Notes Notes
----- -----
...@@ -961,131 +919,68 @@ class DGLGraph(object): ...@@ -961,131 +919,68 @@ class DGLGraph(object):
assert message_func is not None assert message_func is not None
if isinstance(message_func, (tuple, list)): if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func) message_func = BundledMessageFunction(message_func)
self._batch_send(u, v, eid, message_func)
if is_all(edges):
def _batch_send(self, u, v, eid, message_func): eid = ALL
if is_all(u) and is_all(v) and eid is None: u, v, _ = self._graph.edges()
u, v, eid = self._graph.edges() elif isinstance(edges, tuple):
# call UDF u, v = edges
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs)
elif eid is not None:
eid = utils.toindex(eid)
u, v, _ = self._graph.find_edges(eid)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v) u, v, eid = self._graph.edge_ids(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
# TODO(minjie): Fix these codes in next PR.
"""
new_uv = []
msg_target_rows = []
msg_update_rows = []
msg_append_rows = []
for i, (_u, _v, _eid) in enumerate(zip(u, v, eid)):
if _eid in self._msg_edges:
msg_target_rows.append(self._msg_edges.index(_eid))
msg_update_rows.append(i)
else:
new_uv.append((_u, _v))
self._msg_edges.append(_eid)
msg_append_rows.append(i)
msg_target_rows = utils.toindex(msg_target_rows)
msg_update_rows = utils.toindex(msg_update_rows)
msg_append_rows = utils.toindex(msg_append_rows)
if utils.is_dict_like(msgs):
if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs})
else: else:
if len(msg_target_rows) > 0: eid = utils.toindex(edges)
self._msg_frame.update_rows( u, v, _ = self._graph.find_edges(eid)
msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
)
"""
def update_edge(self, u=ALL, v=ALL, edge_func="default", eid=None):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
msgs = message_func(eb)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
It computes the new edge representations using the representations def update_edges(self, edges=ALL, edge_func="default"):
of the source node, target node and the edge itself. """Update features on the given edges.
All node_reprs and edge_reprs are dictionaries.
Parameters Parameters
---------- ----------
u : node, container or tensor edges : edges, optional
The source node(s). Edges can be a pair of endpoint nodes (u, v), or a
v : node, container or tensor tensor of edge ids. The default value is all the edges.
The destination node(s).
edge_func : callable edge_func : callable
The update function. The update function.
eid : optional, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored. Notes
-----
On multigraphs, if u and v are specified, then all the edges
between u and v will be updated.
""" """
if edge_func == "default": if edge_func == "default":
edge_func = self._edge_func edge_func = self._edge_func
assert edge_func is not None assert edge_func is not None
self._batch_update_edge(u, v, eid, edge_func)
if is_all(edges):
def _batch_update_edge(self, u, v, eid, edge_func): eid = ALL
if is_all(u) and is_all(v) and eid is None: u, v, _ = self._graph.edges()
u, v, eid = self._graph.edges() elif isinstance(edges, tuple):
# call the UDF u, v = edges
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr()
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr(new_edge_reprs)
else:
if eid is None:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) # Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(u, v) u, v, eid = self._graph.edge_ids(u, v)
# call the UDF else:
src_reprs = self.get_n_repr(u) eid = utils.toindex(edges)
dst_reprs = self.get_n_repr(v) u, v, _ = self._graph.find_edges(eid)
edge_reprs = self.get_e_repr_by_id(eid)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs) src_data = self.get_n_repr(u)
self.set_e_repr_by_id(new_edge_reprs, eid) edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
self.set_e_repr(edge_func(eb), eid)
def recv(self, def recv(self,
u, u,
...@@ -1093,25 +988,6 @@ class DGLGraph(object): ...@@ -1093,25 +988,6 @@ class DGLGraph(object):
apply_node_func="default"): apply_node_func="default"):
"""Receive and reduce in-coming messages and update representation on node u. """Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
of node u. If no message is found from the predecessors, reduce function
will be skipped.
The reduce function should be compatible with following signature:
(node_reprs, batched_messages) -> node_reprs
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages).
The reduce function can also be pre-defined functions.
An optinoal apply_node function could be specified and should follow following
signature:
node_reprs -> node_reprs
All node_reprs and edge_reprs support tensor and dictionary types.
TODO(minjie): document on zero-in-degree case TODO(minjie): document on zero-in-degree case
TODO(minjie): document on how returned new features are merged with the old features TODO(minjie): document on how returned new features are merged with the old features
TODO(minjie): document on how many times UDFs will be called TODO(minjie): document on how many times UDFs will be called
...@@ -1141,11 +1017,13 @@ class DGLGraph(object): ...@@ -1141,11 +1017,13 @@ class DGLGraph(object):
v_is_all = is_all(v) v_is_all = is_all(v)
if v_is_all: if v_is_all:
v = list(range(self.number_of_nodes())) v = F.arange(0, self.number_of_nodes(), dtype=F.int64)
elif isinstance(v, int):
v = [v]
v = utils.toindex(v)
if len(v) == 0: if len(v) == 0:
# no vertex to be triggered. # no vertex to be triggered.
return return
v = utils.toindex(v)
# degree bucketing # degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v) degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
...@@ -1162,7 +1040,7 @@ class DGLGraph(object): ...@@ -1162,7 +1040,7 @@ class DGLGraph(object):
has_zero_degree = True has_zero_degree = True
continue continue
bkt_len = len(v_bkt) bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt) v_data = self.get_n_repr(v_bkt)
uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt) uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msgs = self._msg_frame.select_rows(in_msg_ids) in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...). # Reshape the column tensor to (B, Deg, ...).
...@@ -1173,7 +1051,8 @@ class DGLGraph(object): ...@@ -1173,7 +1051,8 @@ class DGLGraph(object):
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.tousertensor()) reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs)) nb = NodeBatch(self, v_bkt, v_data, reshaped_in_msgs)
new_reprs.append(reduce_func(nb))
# TODO(minjie): clear partial messages # TODO(minjie): clear partial messages
self.reset_messages() self.reset_messages()
...@@ -1195,26 +1074,26 @@ class DGLGraph(object): ...@@ -1195,26 +1074,26 @@ class DGLGraph(object):
self.set_n_repr(new_reprs, reordered_v) self.set_n_repr(new_reprs, reordered_v)
def send_and_recv(self, def send_and_recv(self,
u=None, v=None, edges,
message_func="default", message_func="default",
reduce_func="default", reduce_func="default",
apply_node_func="default", apply_node_func="default"):
eid=None): """Send messages along edges and receive them on the targets.
"""Trigger the message function on u->v and update v, or on edge eid
and update the destination nodes.
Parameters Parameters
---------- ----------
u : optional, node, container or tensor edges : edges
The source node(s). Edges can be a pair of endpoint nodes (u, v), or a
v : optional, node, container or tensor tensor of edge ids. The default value is all the edges.
The destination node(s). message_func : callable, optional
message_func : callable The message function. Registered function will be used if not
The message function. specified.
reduce_func : callable reduce_func : callable, optional
The reduce function. The reduce function. Registered function will be used if not
specified.
apply_node_func : callable, optional apply_node_func : callable, optional
The update function. The update function. Registered function will be used if not
specified.
Notes Notes
----- -----
...@@ -1223,69 +1102,58 @@ class DGLGraph(object): ...@@ -1223,69 +1102,58 @@ class DGLGraph(object):
""" """
if message_func == "default": if message_func == "default":
message_func = self._message_func message_func = self._message_func
elif isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if reduce_func == "default": if reduce_func == "default":
reduce_func = self._reduce_func reduce_func = self._reduce_func
elif isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
if eid is None: if isinstance(edges, tuple):
if u is None or v is None: u, v = edges
raise ValueError('u and v must be given if eid is None')
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
if len(u) == 0: if len(u) == 0:
# no edges to be triggered # no edges to be triggered
assert len(v) == 0
return return
unique_v = utils.toindex(F.unique(v.tousertensor()))
if not self.is_multigraph:
executor = scheduler.get_executor( executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v, 'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func) message_func=message_func, reduce_func=reduce_func)
else: else:
eid = utils.toindex(eid)
if len(eid) == 0:
# no edges to be triggered
return
executor = None executor = None
if executor: if executor:
new_reprs = executor.run() accum = executor.run()
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
elif eid is not None:
_, v, _ = self._graph.find_edges(eid)
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(quan): replace with the new DegreeBucketingScheduler
self.send(eid=eid, message_func=message_func)
self.recv(unique_v, reduce_func, apply_node_func)
else: else:
# handle multiple message and reduce func
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
# message func # message func
u, v = utils.edge_broadcasting(u, v) src_data = self.get_n_repr(u)
src_reprs = self.get_n_repr(u) edge_data = self.get_e_repr(eid)
edge_reprs = self.get_e_repr(u, v) dst_data = self.get_n_repr(v)
msgs = message_func(src_reprs, edge_reprs) eb = EdgeBatch(self, (u, v, eid),
msg_frame = FrameRef() src_data, edge_data, dst_data)
msg_frame.append(msgs) msgs = message_func(eb)
msg_frame = FrameRef(Frame(msgs))
# recv with degree bucketing # recv with degree bucketing
executor = scheduler.get_recv_executor(graph=self, executor = scheduler.get_recv_executor(graph=self,
reduce_func=reduce_func, reduce_func=reduce_func,
message_frame=msg_frame, message_frame=msg_frame,
edges=(u, v)) edges=(u, v))
new_reprs = executor.run() assert executor is not None
accum = executor.run()
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
def pull(self, def pull(self,
v, v,
...@@ -1309,7 +1177,7 @@ class DGLGraph(object): ...@@ -1309,7 +1177,7 @@ class DGLGraph(object):
if len(v) == 0: if len(v) == 0:
return return
uu, vv, _ = self._graph.in_edges(v) uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None) self.send_and_recv((uu, vv), message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor()) unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func) self.apply_nodes(unique_v, apply_node_func)
...@@ -1335,7 +1203,7 @@ class DGLGraph(object): ...@@ -1335,7 +1203,7 @@ class DGLGraph(object):
if len(u) == 0: if len(u) == 0:
return return
uu, vv, _ = self._graph.out_edges(u) uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func, self.send_and_recv((uu, vv), message_func,
reduce_func, apply_node_func) reduce_func, apply_node_func)
def update_all(self, def update_all(self,
...@@ -1366,7 +1234,7 @@ class DGLGraph(object): ...@@ -1366,7 +1234,7 @@ class DGLGraph(object):
new_reprs = executor.run() new_reprs = executor.run()
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else: else:
self.send(ALL, ALL, message_func) self.send(ALL, message_func)
self.recv(ALL, reduce_func, apply_node_func) self.recv(ALL, reduce_func, apply_node_func)
def propagate(self, def propagate(self,
...@@ -1401,7 +1269,7 @@ class DGLGraph(object): ...@@ -1401,7 +1269,7 @@ class DGLGraph(object):
else: else:
# NOTE: the iteration can return multiple edges at each step. # NOTE: the iteration can return multiple edges at each step.
for u, v in traverser: for u, v in traverser:
self.send_and_recv(u, v, self.send_and_recv((u, v),
message_func, reduce_func, apply_node_func) message_func, reduce_func, apply_node_func)
def subgraph(self, nodes): def subgraph(self, nodes):
...@@ -1586,18 +1454,19 @@ class DGLGraph(object): ...@@ -1586,18 +1454,19 @@ class DGLGraph(object):
---------- ----------
predicate : callable predicate : callable
The predicate should take in a dict of tensors whose values The predicate should take in a dict of tensors whose values
are concatenation of edge representations by edge ID (same as are concatenation of edge representations by edge ID,
get_e_repr_by_id()), and return a boolean tensor with N elements and return a boolean tensor with N elements indicating which
indicating which node satisfy the predicate. node satisfy the predicate.
edges : container or tensor edges : edges
The edges to filter on Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
Returns Returns
------- -------
tensor tensor
The filtered edges The filtered edges
""" """
e_repr = self.get_e_repr_by_id(edges) e_repr = self.get_e_repr(edges)
e_mask = predicate(e_repr) e_mask = predicate(e_repr)
if is_all(edges): if is_all(edges):
......
...@@ -5,10 +5,11 @@ import numpy as np ...@@ -5,10 +5,11 @@ import numpy as np
from .base import ALL, DGLError from .base import ALL, DGLError
from . import backend as F from . import backend as F
from collections import defaultdict as ddict
from .function import message as fmsg from .function import message as fmsg
from .function import reducer as fred from .function import reducer as fred
from .udf import NodeBatch, EdgeBatch
from . import utils from . import utils
from collections import defaultdict as ddict
from ._ffi.function import _init_api from ._ffi.function import _init_api
...@@ -176,7 +177,7 @@ class DegreeBucketingExecutor(Executor): ...@@ -176,7 +177,7 @@ class DegreeBucketingExecutor(Executor):
# loop over each bucket # loop over each bucket
# FIXME (lingfan): handle zero-degree case # FIXME (lingfan): handle zero-degree case
for deg, vv, msg_id in zip(self.degrees, self.dsts, self.msg_ids): for deg, vv, msg_id in zip(self.degrees, self.dsts, self.msg_ids):
dst_reprs = self.g.get_n_repr(vv) v_data = self.g.get_n_repr(vv)
in_msgs = self.msg_frame.select_rows(msg_id) in_msgs = self.msg_frame.select_rows(msg_id)
def _reshape_fn(msg): def _reshape_fn(msg):
msg_shape = F.shape(msg) msg_shape = F.shape(msg)
...@@ -184,7 +185,8 @@ class DegreeBucketingExecutor(Executor): ...@@ -184,7 +185,8 @@ class DegreeBucketingExecutor(Executor):
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs)) nb = NodeBatch(self.g, vv, v_data, reshaped_in_msgs)
new_reprs.append(self.rfunc(nb))
# Pack all reducer results together # Pack all reducer results together
keys = new_reprs[0].keys() keys = new_reprs[0].keys()
...@@ -320,7 +322,7 @@ class SendRecvExecutor(BasicExecutor): ...@@ -320,7 +322,7 @@ class SendRecvExecutor(BasicExecutor):
@property @property
def edge_repr(self): def edge_repr(self):
if self._edge_repr is None: if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr(self.u, self.v) self._edge_repr = self.g.get_e_repr((self.u, self.v))
return self._edge_repr return self._edge_repr
def _build_adjmat(self): def _build_adjmat(self):
...@@ -432,8 +434,11 @@ def _create_send_and_recv_exec(graph, **kwargs): ...@@ -432,8 +434,11 @@ def _create_send_and_recv_exec(graph, **kwargs):
dst = kwargs.pop('dst') dst = kwargs.pop('dst')
mfunc = kwargs.pop('message_func') mfunc = kwargs.pop('message_func')
rfunc = kwargs.pop('reduce_func') rfunc = kwargs.pop('reduce_func')
if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)): if (isinstance(mfunc, fmsg.BundledMessageFunction)
or isinstance(rfunc, fred.BundledReduceFunction)):
if not isinstance(mfunc, fmsg.BundledMessageFunction):
mfunc = fmsg.BundledMessageFunction(mfunc) mfunc = fmsg.BundledMessageFunction(mfunc)
if not isinstance(rfunc, fred.BundledReduceFunction):
rfunc = fred.BundledReduceFunction(rfunc) rfunc = fred.BundledReduceFunction(rfunc)
exec_cls = BundledSendRecvExecutor exec_cls = BundledSendRecvExecutor
else: else:
......
"""User-defined function related data structures."""
from __future__ import absolute_import
from collections import Mapping
from .base import ALL, is_all
from . import backend as F
from . import utils
class EdgeBatch(object):
    """The object that represents a batch of edges.

    Parameters
    ----------
    g : DGLGraph
        The graph object.
    edges : tuple of utils.Index
        The edge tuple (u, v, eid). eid can be ALL.
    src_data : dict of tensors
        The src node features.
    edge_data : dict of tensors
        The edge features.
    dst_data : dict of tensors
        The dst node features.
    """
    def __init__(self, g, edges, src_data, edge_data, dst_data):
        self._g = g
        self._edges = edges
        self._src_data = src_data
        self._edge_data = edge_data
        self._dst_data = dst_data

    @property
    def src(self):
        """Return the feature data of the source nodes.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._src_data

    @property
    def dst(self):
        """Return the feature data of the destination nodes.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._dst_data

    @property
    def data(self):
        """Return the edge feature data.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._edge_data

    def edges(self):
        """Return the edges contained in this batch.

        Returns
        -------
        tuple of tensors
            The edge tuple (u, v, eid).
        """
        if is_all(self._edges[2]):
            # Lazily materialize the ALL sentinel into concrete edge ids.
            # ``self._edges`` is a tuple, which does not support item
            # assignment, so rebuild the whole triple instead of writing
            # ``self._edges[2] = ...`` (which raised TypeError).
            eid = utils.toindex(F.arange(
                0, self._g.number_of_edges(), dtype=F.int64))
            self._edges = (self._edges[0], self._edges[1], eid)
        u, v, eid = self._edges
        return (u.tousertensor(), v.tousertensor(), eid.tousertensor())

    def batch_size(self):
        """Return the number of edges in this edge batch."""
        return len(self._edges[0])

    def __len__(self):
        """Return the number of edges in this edge batch."""
        return self.batch_size()
class NodeBatch(object):
    """A batch of nodes together with their features and incoming messages.

    Parameters
    ----------
    g : DGLGraph
        The graph object.
    nodes : utils.Index or ALL
        The node ids.
    data : dict of tensors
        The node features.
    msgs : dict of tensors, optional
        The messages.
    """
    def __init__(self, g, nodes, data, msgs=None):
        self._g = g
        self._nodes = nodes
        self._data = data
        self._msgs = msgs

    @property
    def data(self):
        """Return the node feature data.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._data

    @property
    def mailbox(self):
        """Return the received messages.

        If no messages received, a None will be returned.

        Returns
        -------
        dict of str to tensors
            The message data.
        """
        return self._msgs

    def nodes(self):
        """Return the nodes contained in this batch.

        Returns
        -------
        tensor
            The nodes.
        """
        if is_all(self._nodes):
            # Materialize the ALL sentinel into concrete ids on demand.
            all_ids = F.arange(0, self._g.number_of_nodes(), dtype=F.int64)
            self._nodes = utils.toindex(all_ids)
        return self._nodes.tousertensor()

    def batch_size(self):
        """Return the number of nodes in this node batch."""
        if not is_all(self._nodes):
            return len(self._nodes)
        return self._g.number_of_nodes()

    def __len__(self):
        """Return the number of nodes in this node batch."""
        return self.batch_size()
...@@ -20,8 +20,14 @@ class Index(object): ...@@ -20,8 +20,14 @@ class Index(object):
def _dispatch(self, data): def _dispatch(self, data):
"""Store data based on its type.""" """Store data based on its type."""
if isinstance(data, Tensor): if isinstance(data, Tensor):
if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1): if not (F.dtype(data) == F.int64):
raise ValueError('Index data must be an int64 vector, but got: %s' % str(data))
if len(F.shape(data)) > 1:
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data)) raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
if len(F.shape(data)) == 0:
# a tensor of one int
self._dispatch(int(data))
else:
self._user_tensor_data[F.get_context(data)] = data self._user_tensor_data[F.get_context(data)] = data
elif isinstance(data, nd.NDArray): elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1): if not (data.dtype == 'int64' and len(data.shape) == 1):
...@@ -343,3 +349,18 @@ def reorder(dict_like, index): ...@@ -343,3 +349,18 @@ def reorder(dict_like, index):
idx_ctx = index.tousertensor(F.get_context(val)) idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx) new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict return new_dict
def parse_edges_tuple(edges):
    """Parse the given edges and return the tuple.

    Parameters
    ----------
    edges : edges
        Edges can be a pair of endpoint nodes (u, v), or a
        tensor of edge ids. The default value is all the edges.

    Returns
    -------
    A tuple of (u, v, eid)

    Raises
    ------
    NotImplementedError
        Always. This helper is still a placeholder: resolving a (u, v)
        pair or a tensor of edge ids into a full (u, v, eid) triple
        requires the graph structure, which this signature does not
        receive yet.
    """
    # Fail loudly instead of silently returning None (the previous body
    # was a bare ``pass``), which callers could mistake for a valid result.
    raise NotImplementedError('parse_edges_tuple is not implemented yet')
"""Views of DGLGraph."""
from __future__ import absolute_import
from collections import MutableMapping, namedtuple
from .base import ALL, is_all, DGLError
from . import backend as F
from . import utils
NodeSpace = namedtuple('NodeSpace', ['data'])

class NodeView(object):
    """A NodeView class to act as G.nodes for a DGLGraph.

    Compared with networkx's NodeView, DGL's NodeView is not
    a map. DGL's NodeView supports creating data view from
    a given list of nodes.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    """
    # NOTE: the original spelled this ``__slot__``, which Python silently
    # ignores; ``__slots__`` is the attribute that actually suppresses the
    # per-instance ``__dict__``.
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __len__(self):
        """Return the number of nodes in the graph."""
        return self._graph.number_of_nodes()

    def __getitem__(self, nodes):
        """Return a NodeSpace wrapping a data view over the given nodes.

        Only the full slice ``[:]`` is accepted as a slice argument.
        """
        if isinstance(nodes, slice):
            # slice: only the trivial ":" (select all nodes) is allowed
            if not (nodes.start is None and nodes.stop is None
                    and nodes.step is None):
                raise DGLError('Currently only full slice ":" is supported')
            return NodeSpace(data=NodeDataView(self._graph, ALL))
        else:
            return NodeSpace(data=NodeDataView(self._graph, nodes))

    def __call__(self):
        """Return the nodes."""
        return F.arange(0, len(self))
class NodeDataView(MutableMapping):
    """A mutable, dict-like view of the feature data of a set of nodes.

    Reads and writes are delegated to the underlying graph's node frame,
    restricted to ``nodes`` (which may be the ALL sentinel).
    """
    # Fixed from the original's misspelled ``__slot__`` (silently ignored
    # by Python).
    __slots__ = ['_graph', '_nodes']

    def __init__(self, graph, nodes):
        self._graph = graph
        self._nodes = nodes

    def __getitem__(self, key):
        return self._graph.get_n_repr(self._nodes)[key]

    def __setitem__(self, key, val):
        self._graph.set_n_repr({key : val}, self._nodes)

    def __delitem__(self, key):
        # Deleting a feature column is only meaningful for the whole
        # node frame, not for a subset of rows.
        if not is_all(self._nodes):
            raise DGLError('Delete feature data is not supported on only a subset'
                           ' of nodes. Please use `del G.ndata[key]` instead.')
        self._graph.pop_n_repr(key)

    def __len__(self):
        return len(self._graph._node_frame)

    def __iter__(self):
        return iter(self._graph._node_frame)

    def __repr__(self):
        data = self._graph.get_n_repr(self._nodes)
        return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])

class EdgeView(object):
    """An EdgeView class to act as G.edges for a DGLGraph.

    Supports creating a data view from a given set of edges via
    ``G.edges[u, v].data`` and listing all edges via ``G.edges()``.
    """
    # Fixed from the original's misspelled ``__slot__`` (silently ignored
    # by Python).
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __len__(self):
        """Return the number of edges in the graph."""
        return self._graph.number_of_edges()

    def __getitem__(self, edges):
        """Return an EdgeSpace wrapping a data view over the given edges.

        Only the full slice ``[:]`` is accepted as a slice argument.
        """
        if isinstance(edges, slice):
            # slice: only the trivial ":" (select all edges) is allowed
            if not (edges.start is None and edges.stop is None
                    and edges.step is None):
                raise DGLError('Currently only full slice ":" is supported')
            return EdgeSpace(data=EdgeDataView(self._graph, ALL))
        else:
            return EdgeSpace(data=EdgeDataView(self._graph, edges))

    def __call__(self, *args, **kwargs):
        """Return all the edges."""
        return self._graph.all_edges(*args, **kwargs)
class EdgeDataView(MutableMapping):
    """A mutable, dict-like view of the feature data of a set of edges.

    Reads and writes are delegated to the underlying graph's edge frame,
    restricted to ``edges`` (which may be the ALL sentinel).
    """
    # Fixed from the original's misspelled ``__slot__`` (silently ignored
    # by Python).
    __slots__ = ['_graph', '_edges']

    def __init__(self, graph, edges):
        self._graph = graph
        self._edges = edges

    def __getitem__(self, key):
        return self._graph.get_e_repr(self._edges)[key]

    def __setitem__(self, key, val):
        self._graph.set_e_repr({key : val}, self._edges)

    def __delitem__(self, key):
        # Deleting a feature column is only meaningful for the whole
        # edge frame, not for a subset of rows.  (Message fixed: the
        # original said "subset of nodes", copy-pasted from NodeDataView.)
        if not is_all(self._edges):
            raise DGLError('Delete feature data is not supported on only a subset'
                           ' of edges. Please use `del G.edata[key]` instead.')
        self._graph.pop_e_repr(key)

    def __len__(self):
        return len(self._graph._edge_frame)

    def __iter__(self):
        return iter(self._graph._edge_frame)

    def __repr__(self):
        data = self._graph.get_e_repr(self._edges)
        return repr({key : data[key] for key in self._graph._edge_frame})
...@@ -11,20 +11,20 @@ def check_eq(a, b): ...@@ -11,20 +11,20 @@ def check_eq(a, b):
assert a.shape == b.shape assert a.shape == b.shape
assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape))) assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def message_func(src, edge): def message_func(edges):
assert len(src['h'].shape) == 2 assert len(edges.src['h'].shape) == 2
assert src['h'].shape[1] == D assert edges.src['h'].shape[1] == D
return {'m' : src['h']} return {'m' : edges.src['h']}
def reduce_func(node, msgs): def reduce_func(nodes):
msgs = msgs['m'] msgs = nodes.mailbox['m']
reduce_msg_shapes.add(tuple(msgs.shape)) reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3 assert len(msgs.shape) == 3
assert msgs.shape[2] == D assert msgs.shape[2] == D
return {'m' : mx.nd.sum(msgs, 1)} return {'m' : mx.nd.sum(msgs, 1)}
def apply_node_func(node): def apply_node_func(nodes):
return {'h' : node['h'] + node['m']} return {'h' : nodes.data['h'] + nodes.data['m']}
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
...@@ -38,7 +38,7 @@ def generate_graph(grad=False): ...@@ -38,7 +38,7 @@ def generate_graph(grad=False):
ncol = mx.nd.random.normal(shape=(10, D)) ncol = mx.nd.random.normal(shape=(10, D))
if grad: if grad:
ncol.attach_grad() ncol.attach_grad()
g.set_n_repr({'h' : ncol}) g.ndata['h'] = ncol
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
...@@ -47,15 +47,15 @@ def test_batch_setter_getter(): ...@@ -47,15 +47,15 @@ def test_batch_setter_getter():
g = generate_graph() g = generate_graph()
# set all nodes # set all nodes
g.set_n_repr({'h' : mx.nd.zeros((10, D))}) g.set_n_repr({'h' : mx.nd.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10 assert _pfc(g.ndata['h']) == [0.] * 10
# pop nodes # pop nodes
assert _pfc(g.pop_n_repr('h')) == [0.] * 10 assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0 assert len(g.ndata) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))}) g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes # set partial nodes
u = mx.nd.array([1, 3, 5], dtype='int64') u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u) g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.] assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes # get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64') u = mx.nd.array([1, 2, 3], dtype='int64')
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.] assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
...@@ -81,77 +81,77 @@ def test_batch_setter_getter(): ...@@ -81,77 +81,77 @@ def test_batch_setter_getter():
9, 0, 16 9, 0, 16
''' '''
# set all edges # set all edges
g.set_e_repr({'l' : mx.nd.zeros((17, D))}) g.edata['l'] = mx.nd.zeros((17, D))
assert _pfc(g.get_e_repr()['l']) == [0.] * 17 assert _pfc(g.edata['l']) == [0.] * 17
# pop edges # pop edges
assert _pfc(g.pop_e_repr('l')) == [0.] * 17 assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0 assert len(g.edata) == 0
g.set_e_repr({'l' : mx.nd.zeros((17, D))}) g.edata['l'] = mx.nd.zeros((17, D))
# set partial edges (many-many) # set partial edges (many-many)
u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64') u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64') v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((5, D))}, u, v) g.edges[u, v].data['l'] = mx.nd.ones((5, D))
truth = [0.] * 17 truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1. truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth assert _pfc(g.edata['l']) == truth
# set partial edges (many-one) # set partial edges (many-one)
u = mx.nd.array([3, 4, 6], dtype='int64') u = mx.nd.array([3, 4, 6], dtype='int64')
v = mx.nd.array([9], dtype='int64') v = mx.nd.array([9], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v) g.edges[u, v].data['l'] = mx.nd.ones((3, D))
truth[5] = truth[7] = truth[11] = 1. truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth assert _pfc(g.edata['l']) == truth
# set partial edges (one-many) # set partial edges (one-many)
u = mx.nd.array([0], dtype='int64') u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([4, 5, 6], dtype='int64') v = mx.nd.array([4, 5, 6], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v) g.edges[u, v].data['l'] = mx.nd.ones((3, D))
truth[6] = truth[8] = truth[10] = 1. truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth assert _pfc(g.edata['l']) == truth
# get partial edges (many-many) # get partial edges (many-many)
u = mx.nd.array([0, 6, 0], dtype='int64') u = mx.nd.array([0, 6, 0], dtype='int64')
v = mx.nd.array([6, 9, 7], dtype='int64') v = mx.nd.array([6, 9, 7], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one) # get partial edges (many-one)
u = mx.nd.array([5, 6, 7], dtype='int64') u = mx.nd.array([5, 6, 7], dtype='int64')
v = mx.nd.array([9], dtype='int64') v = mx.nd.array([9], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many) # get partial edges (one-many)
u = mx.nd.array([0], dtype='int64') u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([3, 4, 5], dtype='int64') v = mx.nd.array([3, 4, 5], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd(): def test_batch_setter_autograd():
with mx.autograd.record(): with mx.autograd.record():
g = generate_graph(grad=True) g = generate_graph(grad=True)
h1 = g.get_n_repr()['h'] h1 = g.ndata['h']
h1.attach_grad() h1.attach_grad()
# partial set # partial set
v = mx.nd.array([1, 2, 8], dtype='int64') v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D)) hh = mx.nd.zeros((len(v), D))
hh.attach_grad() hh.attach_grad()
g.set_n_repr({'h' : hh}, v) g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h'] h2 = g.ndata['h']
h2.backward(mx.nd.ones((10, D)) * 2) h2.backward(mx.nd.ones((10, D)) * 2)
check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])) check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.])) check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))
def test_batch_send(): def test_batch_send():
g = generate_graph() g = generate_graph()
def _fmsg(src, edge): def _fmsg(edges):
assert src['h'].shape == (5, D) assert edges.src['h'].shape == (5, D)
return {'m' : src['h']} return {'m' : edges.src['h']}
g.register_message_func(_fmsg) g.register_message_func(_fmsg)
# many-many send # many-many send
u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64') u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64') v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v) g.send((u, v))
# one-many send # one-many send
u = mx.nd.array([0], dtype='int64') u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64') v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v) g.send((u, v))
# many-one send # many-one send
u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64') u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
v = mx.nd.array([9], dtype='int64') v = mx.nd.array([9], dtype='int64')
g.send(u, v) g.send((u, v))
def test_batch_recv(): def test_batch_recv():
# basic recv test # basic recv test
...@@ -162,7 +162,7 @@ def test_batch_recv(): ...@@ -162,7 +162,7 @@ def test_batch_recv():
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64') u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64') v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
g.send(u, v) g.send((u, v))
#g.recv(th.unique(v)) #g.recv(th.unique(v))
#assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) #assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
#reduce_msg_shapes.clear() #reduce_msg_shapes.clear()
...@@ -177,7 +177,7 @@ def test_update_routines(): ...@@ -177,7 +177,7 @@ def test_update_routines():
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64') u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64') v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
g.send_and_recv(u, v) g.send_and_recv((u, v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
...@@ -208,14 +208,14 @@ def test_reduce_0deg(): ...@@ -208,14 +208,14 @@ def test_reduce_0deg():
g.add_edge(2, 0) g.add_edge(2, 0)
g.add_edge(3, 0) g.add_edge(3, 0)
g.add_edge(4, 0) g.add_edge(4, 0)
def _message(src, edge): def _message(edges):
return {'m' : src['h']} return {'m' : edges.src['h']}
def _reduce(node, msgs): def _reduce(nodes):
return {'h' : node['h'] + msgs['m'].sum(1)} return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = mx.nd.random.normal(shape=(5, 5)) old_repr = mx.nd.random.normal(shape=(5, 5))
g.set_n_repr({'h': old_repr}) g.set_n_repr({'h': old_repr})
g.update_all(_message, _reduce) g.update_all(_message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy()) assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
...@@ -224,25 +224,25 @@ def test_pull_0deg(): ...@@ -224,25 +224,25 @@ def test_pull_0deg():
g = DGLGraph() g = DGLGraph()
g.add_nodes(2) g.add_nodes(2)
g.add_edge(0, 1) g.add_edge(0, 1)
def _message(src, edge): def _message(edges):
return {'m' : src['h']} return {'m' : edges.src['h']}
def _reduce(node, msgs): def _reduce(nodes):
return {'h' : msgs['m'].sum(1)} return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = mx.nd.random.normal(shape=(2, 5)) old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr}) g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce) g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy()) assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
g.pull(1, _message, _reduce) g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
old_repr = mx.nd.random.normal(shape=(2, 5)) old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr}) g.set_n_repr({'h' : old_repr})
g.pull([0, 1], _message, _reduce) g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy()) assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
......
...@@ -10,20 +10,20 @@ def check_eq(a, b): ...@@ -10,20 +10,20 @@ def check_eq(a, b):
assert a.shape == b.shape assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape))) assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(src, edge): def message_func(edges):
assert len(src['h'].shape) == 2 assert len(edges.src['h'].shape) == 2
assert src['h'].shape[1] == D assert edges.src['h'].shape[1] == D
return {'m' : src['h']} return {'m' : edges.src['h']}
def reduce_func(node, msgs): def reduce_func(nodes):
msgs = msgs['m'] msgs = nodes.mailbox['m']
reduce_msg_shapes.add(tuple(msgs.shape)) reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3 assert len(msgs.shape) == 3
assert msgs.shape[2] == D assert msgs.shape[2] == D
return {'accum' : th.sum(msgs, 1)} return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node): def apply_node_func(nodes):
return {'h' : node['h'] + node['accum']} return {'h' : nodes.data['h'] + nodes.data['accum']}
def generate_graph(grad=False): def generate_graph(grad=False):
g = DGLGraph() g = DGLGraph()
...@@ -36,10 +36,11 @@ def generate_graph(grad=False): ...@@ -36,10 +36,11 @@ def generate_graph(grad=False):
# add a back flow from 9 to 0 # add a back flow from 9 to 0
g.add_edge(9, 0) g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad) ncol = Variable(th.randn(10, D), requires_grad=grad)
accumcol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad) ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol}) g.ndata['h'] = ncol
g.edata['w'] = ecol
g.set_n_initializer(lambda shape, dtype : th.zeros(shape)) g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
return g return g
def test_batch_setter_getter(): def test_batch_setter_getter():
...@@ -47,20 +48,20 @@ def test_batch_setter_getter(): ...@@ -47,20 +48,20 @@ def test_batch_setter_getter():
return list(x.numpy()[:,0]) return list(x.numpy()[:,0])
g = generate_graph() g = generate_graph()
# set all nodes # set all nodes
g.set_n_repr({'h' : th.zeros((10, D))}) g.ndata['h'] = th.zeros((10, D))
assert _pfc(g.get_n_repr()['h']) == [0.] * 10 assert th.allclose(g.ndata['h'], th.zeros((10, D)))
# pop nodes # pop nodes
old_len = len(g.get_n_repr()) old_len = len(g.ndata)
assert _pfc(g.pop_n_repr('h')) == [0.] * 10 assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == old_len - 1 assert len(g.ndata) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))}) g.ndata['h'] = th.zeros((10, D))
# set partial nodes # set partial nodes
u = th.tensor([1, 3, 5]) u = th.tensor([1, 3, 5])
g.set_n_repr({'h' : th.ones((3, D))}, u) g.nodes[u].data['h'] = th.ones((3, D))
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.] assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes # get partial nodes
u = th.tensor([1, 2, 3]) u = th.tensor([1, 2, 3])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.] assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]
''' '''
s, d, eid s, d, eid
...@@ -83,75 +84,75 @@ def test_batch_setter_getter(): ...@@ -83,75 +84,75 @@ def test_batch_setter_getter():
9, 0, 16 9, 0, 16
''' '''
# set all edges # set all edges
g.set_e_repr({'l' : th.zeros((17, D))}) g.edata['l'] = th.zeros((17, D))
assert _pfc(g.get_e_repr()['l']) == [0.] * 17 assert _pfc(g.edata['l']) == [0.] * 17
# pop edges # pop edges
old_len = len(g.get_e_repr()) old_len = len(g.edata)
assert _pfc(g.pop_e_repr('l')) == [0.] * 17 assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == old_len - 1 assert len(g.edata) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))}) g.edata['l'] = th.zeros((17, D))
# set partial edges (many-many) # set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9]) u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0]) v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr({'l' : th.ones((5, D))}, u, v) g.edges[u, v].data['l'] = th.ones((5, D))
truth = [0.] * 17 truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1. truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth assert _pfc(g.edata['l']) == truth
# set partial edges (many-one) # set partial edges (many-one)
u = th.tensor([3, 4, 6]) u = th.tensor([3, 4, 6])
v = th.tensor([9]) v = th.tensor([9])
g.set_e_repr({'l' : th.ones((3, D))}, u, v) g.edges[u, v].data['l'] = th.ones((3, D))
truth[5] = truth[7] = truth[11] = 1. truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth assert _pfc(g.edata['l']) == truth
# set partial edges (one-many) # set partial edges (one-many)
u = th.tensor([0]) u = th.tensor([0])
v = th.tensor([4, 5, 6]) v = th.tensor([4, 5, 6])
g.set_e_repr({'l' : th.ones((3, D))}, u, v) g.edges[u, v].data['l'] = th.ones((3, D))
truth[6] = truth[8] = truth[10] = 1. truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth assert _pfc(g.edata['l']) == truth
# get partial edges (many-many) # get partial edges (many-many)
u = th.tensor([0, 6, 0]) u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7]) v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one) # get partial edges (many-one)
u = th.tensor([5, 6, 7]) u = th.tensor([5, 6, 7])
v = th.tensor([9]) v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many) # get partial edges (one-many)
u = th.tensor([0]) u = th.tensor([0])
v = th.tensor([3, 4, 5]) v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.] assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd(): def test_batch_setter_autograd():
g = generate_graph(grad=True) g = generate_graph(grad=True)
h1 = g.get_n_repr()['h'] h1 = g.ndata['h']
# partial set # partial set
v = th.tensor([1, 2, 8]) v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True) hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr({'h' : hh}, v) g.nodes[v].data['h'] = hh
h2 = g.get_n_repr()['h'] h2 = g.ndata['h']
h2.backward(th.ones((10, D)) * 2) h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])) check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.])) check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send(): def test_batch_send():
g = generate_graph() g = generate_graph()
def _fmsg(src, edge): def _fmsg(edges):
assert src['h'].shape == (5, D) assert edges.src['h'].shape == (5, D)
return {'m' : src['h']} return {'m' : edges.src['h']}
g.register_message_func(_fmsg) g.register_message_func(_fmsg)
# many-many send # many-many send
u = th.tensor([0, 0, 0, 0, 0]) u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5]) v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v) g.send((u, v))
# one-many send # one-many send
u = th.tensor([0]) u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5]) v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v) g.send((u, v))
# many-one send # many-one send
u = th.tensor([1, 2, 3, 4, 5]) u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9]) v = th.tensor([9])
g.send(u, v) g.send((u, v))
def test_batch_recv(): def test_batch_recv():
# basic recv test # basic recv test
...@@ -162,11 +163,25 @@ def test_batch_recv(): ...@@ -162,11 +163,25 @@ def test_batch_recv():
u = th.tensor([0, 0, 0, 4, 5, 6]) u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9]) v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
g.send(u, v) g.send((u, v))
g.recv(th.unique(v)) g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
def test_update_edges():
def _upd(edges):
return {'w' : edges.data['w'] * 2}
g = generate_graph()
g.register_edge_func(_upd)
old = g.edata['w']
g.update_edges()
assert th.allclose(old * 2, g.edata['w'])
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_edges((u, v), lambda edges : {'w' : edges.data['w'] * 0.})
eid = g.edge_ids(u, v)
assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
def test_update_routines(): def test_update_routines():
g = generate_graph() g = generate_graph()
g.register_message_func(message_func) g.register_message_func(message_func)
...@@ -177,7 +192,7 @@ def test_update_routines(): ...@@ -177,7 +192,7 @@ def test_update_routines():
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6]) u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9]) v = th.tensor([1, 2, 3, 9, 9, 9])
g.send_and_recv(u, v) g.send_and_recv((u, v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
...@@ -208,14 +223,14 @@ def test_reduce_0deg(): ...@@ -208,14 +223,14 @@ def test_reduce_0deg():
g.add_edge(2, 0) g.add_edge(2, 0)
g.add_edge(3, 0) g.add_edge(3, 0)
g.add_edge(4, 0) g.add_edge(4, 0)
def _message(src, edge): def _message(edges):
return {'m' : src['h']} return {'m' : edges.src['h']}
def _reduce(node, msgs): def _reduce(nodes):
return {'h' : node['h'] + msgs['m'].sum(1)} return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = th.randn(5, 5) old_repr = th.randn(5, 5)
g.set_n_repr({'h' : old_repr}) g.ndata['h'] = old_repr
g.update_all(_message, _reduce) g.update_all(_message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert th.allclose(new_repr[1:], old_repr[1:]) assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0)) assert th.allclose(new_repr[0], old_repr.sum(0))
...@@ -224,26 +239,26 @@ def test_pull_0deg(): ...@@ -224,26 +239,26 @@ def test_pull_0deg():
g = DGLGraph() g = DGLGraph()
g.add_nodes(2) g.add_nodes(2)
g.add_edge(0, 1) g.add_edge(0, 1)
def _message(src, edge): def _message(edges):
return {'m' : src['h']} return {'m' : edges.src['h']}
def _reduce(node, msgs): def _reduce(nodes):
return {'h' : msgs['m'].sum(1)} return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr({'h' : old_repr}) g.ndata['h'] = old_repr
g.pull(0, _message, _reduce) g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1]) assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce) g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr({'h' : old_repr}) g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce) g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()['h'] new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
...@@ -253,27 +268,26 @@ def _disabled_test_send_twice(): ...@@ -253,27 +268,26 @@ def _disabled_test_send_twice():
g.add_nodes(3) g.add_nodes(3)
g.add_edge(0, 1) g.add_edge(0, 1)
g.add_edge(2, 1) g.add_edge(2, 1)
def _message_a(src, edge): def _message_a(edges):
return {'a': src['a']} return {'a': edges.src['a']}
def _message_b(src, edge): def _message_b(edges):
return {'a': src['a'] * 3} return {'a': edges.src['a'] * 3}
def _reduce(node, msgs): def _reduce(nodes):
assert msgs is not None return {'a': nodes.mailbox['a'].max(1)[0]}
return {'a': msgs['a'].max(1)[0]}
old_repr = th.randn(3, 5) old_repr = th.randn(3, 5)
g.set_n_repr({'a': old_repr}) g.ndata['a'] = old_repr
g.send(0, 1, _message_a) g.send((0, 1), _message_a)
g.send(0, 1, _message_b) g.send((0, 1), _message_b)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr[0] * 3) assert th.allclose(new_repr[1], old_repr[0] * 3)
g.set_n_repr({'a': old_repr}) g.ndata['a'] = old_repr
g.send(0, 1, _message_a) g.send((0, 1), _message_a)
g.send(2, 1, _message_b) g.send((2, 1), _message_b)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0]) assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_multigraph(): def test_send_multigraph():
...@@ -284,64 +298,63 @@ def test_send_multigraph(): ...@@ -284,64 +298,63 @@ def test_send_multigraph():
g.add_edge(0, 1) g.add_edge(0, 1)
g.add_edge(2, 1) g.add_edge(2, 1)
def _message_a(src, edge): def _message_a(edges):
return {'a': edge['a']} return {'a': edges.data['a']}
def _message_b(src, edge): def _message_b(edges):
return {'a': edge['a'] * 3} return {'a': edges.data['a'] * 3}
def _reduce(node, msgs): def _reduce(nodes):
assert msgs is not None return {'a': nodes.mailbox['a'].max(1)[0]}
return {'a': msgs['a'].max(1)[0]}
def answer(*args): def answer(*args):
return th.stack(args, 0).max(0)[0] return th.stack(args, 0).max(0)[0]
# send by eid # send by eid
old_repr = th.randn(4, 5) old_repr = th.randn(4, 5)
g.set_n_repr({'a': th.zeros(3, 5)}) g.ndata['a'] = th.zeros(3, 5)
g.set_e_repr({'a': old_repr}) g.edata['a'] = old_repr
g.send(eid=[0, 2], message_func=_message_a) g.send([0, 2], message_func=_message_a)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2])) assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
g.set_n_repr({'a': th.zeros(3, 5)}) g.ndata['a'] = th.zeros(3, 5)
g.set_e_repr({'a': old_repr}) g.edata['a'] = old_repr
g.send(eid=[0, 2, 3], message_func=_message_a) g.send([0, 2, 3], message_func=_message_a)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3])) assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
# send on multigraph # send on multigraph
g.set_n_repr({'a': th.zeros(3, 5)}) g.ndata['a'] = th.zeros(3, 5)
g.set_e_repr({'a': old_repr}) g.edata['a'] = old_repr
g.send([0, 2], [1, 1], _message_a) g.send(([0, 2], [1, 1]), _message_a)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr.max(0)[0]) assert th.allclose(new_repr[1], old_repr.max(0)[0])
# consecutive send and send_on # consecutive send and send_on
g.set_n_repr({'a': th.zeros(3, 5)}) g.ndata['a'] = th.zeros(3, 5)
g.set_e_repr({'a': old_repr}) g.edata['a'] = old_repr
g.send(2, 1, _message_a) g.send((2, 1), _message_a)
g.send(eid=[0, 1], message_func=_message_b) g.send([0, 1], message_func=_message_b)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3])) assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
# consecutive send_on # consecutive send_on
g.set_n_repr({'a': th.zeros(3, 5)}) g.ndata['a'] = th.zeros(3, 5)
g.set_e_repr({'a': old_repr}) g.edata['a'] = old_repr
g.send(eid=0, message_func=_message_a) g.send(0, message_func=_message_a)
g.send(eid=1, message_func=_message_b) g.send(1, message_func=_message_b)
g.recv([1], _reduce) g.recv(1, _reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3)) assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
# send_and_recv_on # send_and_recv_on
g.set_n_repr({'a': th.zeros(3, 5)}) g.ndata['a'] = th.zeros(3, 5)
g.set_e_repr({'a': old_repr}) g.edata['a'] = old_repr
g.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce) g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.get_n_repr()['a'] new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3])) assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5)) assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
...@@ -353,29 +366,25 @@ def test_dynamic_addition(): ...@@ -353,29 +366,25 @@ def test_dynamic_addition():
# Test node addition # Test node addition
g.add_nodes(N) g.add_nodes(N)
g.set_n_repr({'h1': th.randn(N, D), g.ndata.update({'h1': th.randn(N, D),
'h2': th.randn(N, D)}) 'h2': th.randn(N, D)})
g.add_nodes(3) g.add_nodes(3)
n_repr = g.get_n_repr() assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3
assert n_repr['h1'].shape[0] == n_repr['h2'].shape[0] == N + 3
# Test edge addition # Test edge addition
g.add_edge(0, 1) g.add_edge(0, 1)
g.add_edge(1, 0) g.add_edge(1, 0)
g.set_e_repr({'h1': th.randn(2, D), g.edata.update({'h1': th.randn(2, D),
'h2': th.randn(2, D)}) 'h2': th.randn(2, D)})
e_repr = g.get_e_repr() assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 2
g.add_edges([0, 2], [2, 0]) g.add_edges([0, 2], [2, 0])
e_repr = g.get_e_repr() g.edata['h1'] = th.randn(4, D)
g.set_e_repr({'h1': th.randn(4, D)}) assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 4
g.add_edge(1, 2) g.add_edge(1, 2)
g.set_e_repr_by_id({'h1': th.randn(1, D)}, eid=4) g.edges[4].data['h1'] = th.randn(1, D)
e_repr = g.get_e_repr() assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 5
if __name__ == '__main__': if __name__ == '__main__':
...@@ -383,6 +392,7 @@ if __name__ == '__main__': ...@@ -383,6 +392,7 @@ if __name__ == '__main__':
test_batch_setter_autograd() test_batch_setter_autograd()
test_batch_send() test_batch_send()
test_batch_recv() test_batch_recv()
test_update_edges()
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
......
...@@ -18,8 +18,8 @@ def tree1(): ...@@ -18,8 +18,8 @@ def tree1():
g.add_edge(4, 1) g.add_edge(4, 1)
g.add_edge(1, 0) g.add_edge(1, 0)
g.add_edge(2, 0) g.add_edge(2, 0)
g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])}) g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
g.set_e_repr({'h' : th.randn(4, 10)}) g.edata['h'] = th.randn(4, 10)
return g return g
def tree2(): def tree2():
...@@ -37,17 +37,17 @@ def tree2(): ...@@ -37,17 +37,17 @@ def tree2():
g.add_edge(0, 4) g.add_edge(0, 4)
g.add_edge(4, 1) g.add_edge(4, 1)
g.add_edge(3, 1) g.add_edge(3, 1)
g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])}) g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
g.set_e_repr({'h' : th.randn(4, 10)}) g.edata['h'] = th.randn(4, 10)
return g return g
def test_batch_unbatch(): def test_batch_unbatch():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
n1 = t1.get_n_repr()['h'] n1 = t1.ndata['h']
n2 = t2.get_n_repr()['h'] n2 = t2.ndata['h']
e1 = t1.get_e_repr()['h'] e1 = t1.edata['h']
e2 = t2.get_e_repr()['h'] e2 = t2.edata['h']
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
assert bg.number_of_nodes() == 10 assert bg.number_of_nodes() == 10
...@@ -57,10 +57,10 @@ def test_batch_unbatch(): ...@@ -57,10 +57,10 @@ def test_batch_unbatch():
assert bg.batch_num_edges == [4, 4] assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg) tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h']) assert th.allclose(t1.ndata['h'], tt1.ndata['h'])
assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h']) assert th.allclose(t1.edata['h'], tt1.edata['h'])
assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h']) assert th.allclose(t2.ndata['h'], tt2.ndata['h'])
assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h']) assert th.allclose(t2.edata['h'], tt2.edata['h'])
def test_batch_unbatch1(): def test_batch_unbatch1():
t1 = tree1() t1 = tree1()
...@@ -74,29 +74,29 @@ def test_batch_unbatch1(): ...@@ -74,29 +74,29 @@ def test_batch_unbatch1():
assert b2.batch_num_edges == [4, 4, 4] assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2) s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h']) assert th.allclose(t2.ndata['h'], s1.ndata['h'])
assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h']) assert th.allclose(t2.edata['h'], s1.edata['h'])
assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h']) assert th.allclose(t1.ndata['h'], s2.ndata['h'])
assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h']) assert th.allclose(t1.edata['h'], s2.edata['h'])
assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h']) assert th.allclose(t2.ndata['h'], s3.ndata['h'])
assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h']) assert th.allclose(t2.edata['h'], s3.edata['h'])
def test_batch_sendrecv(): def test_batch_sendrecv():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: {'m' : src['h']}) bg.register_message_func(lambda edges: {'m' : edges.src['h']})
bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)}) bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
u = [3, 4, 2 + 5, 0 + 5] u = [3, 4, 2 + 5, 0 + 5]
v = [1, 1, 4 + 5, 4 + 5] v = [1, 1, 4 + 5, 4 + 5]
bg.send(u, v) bg.send((u, v))
bg.recv(v) bg.recv(v)
t1, t2 = dgl.unbatch(bg) t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()['h'][1] == 7 assert t1.ndata['h'][1] == 7
assert t2.get_n_repr()['h'][4] == 2 assert t2.ndata['h'][4] == 2
def test_batch_propagate(): def test_batch_propagate():
...@@ -104,8 +104,8 @@ def test_batch_propagate(): ...@@ -104,8 +104,8 @@ def test_batch_propagate():
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: {'m' : src['h']}) bg.register_message_func(lambda edges: {'m' : edges.src['h']})
bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)}) bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
# get leaves. # get leaves.
order = [] order = []
...@@ -123,23 +123,23 @@ def test_batch_propagate(): ...@@ -123,23 +123,23 @@ def test_batch_propagate():
bg.propagate(traverser=order) bg.propagate(traverser=order)
t1, t2 = dgl.unbatch(bg) t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()['h'][0] == 9 assert t1.ndata['h'][0] == 9
assert t2.get_n_repr()['h'][1] == 5 assert t2.ndata['h'][1] == 5
def test_batched_edge_ordering(): def test_batched_edge_ordering():
g1 = dgl.DGLGraph() g1 = dgl.DGLGraph()
g1.add_nodes(6) g1.add_nodes(6)
g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1]) g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
e1 = th.randn(5, 10) e1 = th.randn(5, 10)
g1.set_e_repr({'h' : e1}) g1.edata['h'] = e1
g2 = dgl.DGLGraph() g2 = dgl.DGLGraph()
g2.add_nodes(6) g2.add_nodes(6)
g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0]) g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
e2 = th.randn(6, 10) e2 = th.randn(6, 10)
g2.set_e_repr({'h' : e2}) g2.edata['h'] = e2
g = dgl.batch([g1, g2]) g = dgl.batch([g1, g2])
r1 = g.get_e_repr()['h'][g.edge_id(4, 5)] r1 = g.edata['h'][g.edge_id(4, 5)]
r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)] r2 = g1.edata['h'][g1.edge_id(4, 5)]
assert th.equal(r1, r2) assert th.equal(r1, r2)
def test_batch_no_edge(): def test_batch_no_edge():
......
...@@ -12,8 +12,8 @@ def test_filter(): ...@@ -12,8 +12,8 @@ def test_filter():
n_repr[[1, 3]] = 1 n_repr[[1, 3]] = 1
e_repr[[1, 3]] = 1 e_repr[[1, 3]] = 1
g.set_n_repr({'a': n_repr}) g.ndata['a'] = n_repr
g.set_e_repr({'a': e_repr}) g.edata['a'] = e_repr
def predicate(r): def predicate(r):
return r['a'].max(1)[0] > 0 return r['a'].max(1)[0] > 0
......
...@@ -6,7 +6,7 @@ def generate_graph(): ...@@ -6,7 +6,7 @@ def generate_graph():
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.add_nodes(10) # 10 nodes. g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11, dtype=th.float) h = th.arange(1, 11, dtype=th.float)
g.set_n_repr({'h': h}) g.ndata['h'] = h
# create a graph where 0 is the source and 9 is the sink # create a graph where 0 is the source and 9 is the sink
for i in range(1, 9): for i in range(1, 9):
g.add_edge(0, i) g.add_edge(0, i)
...@@ -15,29 +15,11 @@ def generate_graph(): ...@@ -15,29 +15,11 @@ def generate_graph():
g.add_edge(9, 0) g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\ h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.]) 1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr({'h' : h}) g.edata['h'] = h
return g return g
def generate_graph1(): def reducer_both(nodes):
"""graph with anonymous repr""" return {'h' : th.sum(nodes.mailbox['m'], 1)}
g = dgl.DGLGraph()
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11, dtype=th.float)
h = th.arange(1, 11, dtype=th.float)
g.set_n_repr(h)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr(h)
return g
def reducer_both(node, msgs):
return {'h' : th.sum(msgs['m'], 1)}
def test_copy_src(): def test_copy_src():
# copy_src with both fields # copy_src with both fields
...@@ -45,7 +27,7 @@ def test_copy_src(): ...@@ -45,7 +27,7 @@ def test_copy_src():
g.register_message_func(fn.copy_src(src='h', out='m')) g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge(): def test_copy_edge():
...@@ -54,7 +36,7 @@ def test_copy_edge(): ...@@ -54,7 +36,7 @@ def test_copy_edge():
g.register_message_func(fn.copy_edge(edge='h', out='m')) g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge(): def test_src_mul_edge():
...@@ -63,7 +45,7 @@ def test_src_mul_edge(): ...@@ -63,7 +45,7 @@ def test_src_mul_edge():
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m')) g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both) g.register_reduce_func(reducer_both)
g.update_all() g.update_all()
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.ndata['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment