"third_party/libxsmm/scripts/libxsmm_source.sh" did not exist on "3359c1f1af3d67082d590a17a1663169e5fe29b3"
Unverified Commit 68ec6247 authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[API][Doc] API change & basic tutorials (#113)

* Add SH tutorials

* setup sphinx-gallery; work on graph tutorial

* draft dglgraph tutorial

* update readme to include document url

* rm obsolete file

* Draft the message passing tutorial

* Capsule code (#102)

* add capsule example

* clean code

* better naming

* better naming

* [GCN]tutorial scaffold

* fix capsule example code

* remove previous capsule example code

* graph struc edit

* modified:   2_graph.py

* update doc of capsule

* update capsule docs

* update capsule docs

* add msg passing primer

* GCN-GAT tutorial Section 1 and 2

* comment for API improvement

* section 3

* Tutorial API change (#115)

* change the API as discussed; toy example

* enable the new set/get syntax

* fixed pytorch utest

* fixed gcn example

* fixed gat example

* fixed mx utests

* fix mx utest

* delete apply edges; add utest for update_edges

* small change on toy example

* fix utest

* fix out in degrees bug

* update pagerank example and add it to CI

* add delitem for dataview

* make edges() return form that is compatible with send/update_edges etc

* fix index bug when the given data is one-int-tensor

* fix doc
parent 2ecd2b23
......@@ -17,7 +17,7 @@ def build_dgl() {
}
dir ('build') {
sh 'cmake ..'
sh 'make -j$(nproc)'
sh 'make -j4'
}
}
......
......@@ -46,6 +46,7 @@ extensions = [
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.graphviz',
'sphinx_gallery.gen_gallery',
]
......
from __future__ import division
import networkx as nx
from dgl.graph import DGLGraph
DAMP = 0.85  # PageRank damping factor
N = 100      # number of nodes in the random graph
K = 10       # number of power-iteration rounds
def message_func(src, dst, edge):
    """Message UDF (old API): each source node sends its PageRank mass
    split evenly over its out-degree. `dst` and `edge` are unused but
    required by the old three-argument message signature."""
    pv, deg = src['pv'], src['deg']
    return pv / deg
def update_func(node, accum):
    """Node update UDF: damped PageRank combination of the reduced
    (summed) incoming mass `accum`. Uses module-level DAMP and N."""
    new_pv = (1 - DAMP) / N + DAMP * accum
    return {'pv': new_pv}
def compute_pagerank(g):
    """Run K rounds of synchronous PageRank with the (old) DGLGraph UDF API.

    Parameters
    ----------
    g : networkx graph
        Input graph; wrapped into a DGLGraph below.

    Returns
    -------
    list
        Per-node PageRank value for every node id in ``g.nodes()``.
    """
    g = DGLGraph(g)
    print(g.number_of_edges(), g.number_of_nodes())
    # Register the UDFs once; update_all() below picks them up implicitly.
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')
    # init pv value
    for n in g.nodes():
        # 1 / N is true division thanks to `from __future__ import division`.
        g.node[n]['pv'] = 1 / N
        # Cache the out-degree as node data so message_func can divide by it.
        g.node[n]['deg'] = g.out_degree(n)
    # pagerank
    for k in range(K):
        g.update_all()
    return [g.node[n]['pv'] for n in g.nodes()]
if __name__ == '__main__':
    # Demo: PageRank on a random Erdos-Renyi graph with N nodes.
    g = nx.erdos_renyi_graph(N, 0.05)
    pv = compute_pagerank(g)
    print(pv)
......@@ -16,18 +16,18 @@ import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gat_message(src, edge):
    """GAT message UDF (old two-argument API): forward the source node's
    projected features and attention logits unchanged."""
    msg = {}
    msg['ft'] = src['ft']
    msg['a2'] = src['a2']
    return msg
def gat_message(edges):
    """GAT message UDF (EdgeBatch API): carry the source node's projected
    features 'ft' and attention logits 'a2' along each edge."""
    source = edges.src
    return {'ft': source['ft'], 'a2': source['a2']}
class GATReduce(nn.Module):
def __init__(self, attn_drop):
super(GATReduce, self).__init__()
self.attn_drop = attn_drop
def forward(self, node, msgs):
a1 = torch.unsqueeze(node['a1'], 1) # shape (B, 1, 1)
a2 = msgs['a2'] # shape (B, deg, 1)
ft = msgs['ft'] # shape (B, deg, D)
def forward(self, nodes):
a1 = torch.unsqueeze(nodes.data['a1'], 1) # shape (B, 1, 1)
a2 = nodes.mailbox['a2'] # shape (B, deg, 1)
ft = nodes.mailbox['ft'] # shape (B, deg, D)
# attention
a = a1 + a2 # shape (B, deg, 1)
e = F.softmax(F.leaky_relu(a), dim=1)
......@@ -46,13 +46,13 @@ class GATFinalize(nn.Module):
if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node):
ret = node['accum']
def forward(self, nodes):
ret = nodes.data['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
ret = self.residual_fc(nodes.data['h']) + ret
else:
ret = node['h'] + ret
ret = nodes.data['h'] + ret
return {'head%d' % self.headid : self.activation(ret)}
class GATPrepare(nn.Module):
......@@ -120,7 +120,7 @@ class GAT(nn.Module):
for hid in range(self.num_heads):
i = l * self.num_heads + hid
# prepare
self.g.set_n_repr(self.prp[i](last))
self.g.ndata.update(self.prp[i](last))
# message passing
self.g.update_all(gat_message, self.red[i], self.fnl[i])
# merge all the heads
......@@ -128,7 +128,7 @@ class GAT(nn.Module):
[self.g.pop_n_repr('head%d' % hid) for hid in range(self.num_heads)],
dim=1)
# output projection
self.g.set_n_repr(self.prp[-1](last))
self.g.ndata.update(self.prp[-1](last))
self.g.update_all(gat_message, self.red[-1], self.fnl[-1])
return self.g.pop_n_repr('head0')
......
......@@ -14,24 +14,22 @@ import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
def gcn_msg(src, edge):
    """GCN message UDF (old two-argument API): copy the source node's
    feature 'h' as the message 'm'. `edge` is unused."""
    message = src['h']
    return {'m': message}
def gcn_msg(edges):
    """GCN message UDF (EdgeBatch API): forward the source feature 'h'
    along each edge under the message name 'm'."""
    return dict(m=edges.src['h'])
def gcn_reduce(node, msgs):
    """GCN reduce UDF (old API): sum incoming messages 'm' of shape
    (B, deg, D) over the neighbor axis, yielding node feature 'h'."""
    aggregated = torch.sum(msgs['m'], 1)
    return {'h': aggregated}
def gcn_reduce(nodes):
    """GCN reduce UDF (NodeBatch API): sum mailbox messages 'm' over the
    incoming-edge axis to produce the new node feature 'h'."""
    inbox = nodes.mailbox['m']
    return {'h': torch.sum(inbox, 1)}
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node['h'])
def forward(self, nodes):
h = self.linear(nodes.data['h'])
if self.activation:
h = self.activation(h)
return {'h' : h}
class GCN(nn.Module):
......@@ -62,13 +60,13 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr({'h' : features})
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
if self.dropout:
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h')
......
......@@ -22,8 +22,8 @@ class NodeApplyModule(nn.Module):
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node['h'])
def forward(self, nodes):
h = self.linear(nodes.data['h'])
if self.activation:
h = self.activation(h)
......@@ -57,13 +57,13 @@ class GCN(nn.Module):
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr({'h' : features})
self.g.ndata['h'] = features
for layer in self.layers:
# apply dropout
if self.dropout:
self.g.apply_nodes(apply_node_func=
lambda node: {'h': self.dropout(node['h'])})
lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'),
layer)
......
import networkx as nx
import torch
import dgl
import dgl.function as fn

N = 100  # number of nodes in the random graph
# Fix: networkx exposes erdos_renyi_graph at top level; `nx.nx.erdos_renyi_graph`
# raises AttributeError because the networkx package has no `nx` attribute.
g = nx.erdos_renyi_graph(N, 0.05)
g = dgl.DGLGraph(g)
DAMP = 0.85  # PageRank damping factor
K = 10       # number of PageRank iterations
def compute_pagerank(g):
    """Run K rounds of PageRank using DGL builtin message/reduce functions.

    Uses the module-level constants N, DAMP and K; assumes ``g`` has N
    nodes so that the uniform initialization matches.

    Parameters
    ----------
    g : DGLGraph
        The graph to compute PageRank on (node data 'pv' is written).

    Returns
    -------
    tensor
        1-D tensor of per-node PageRank values.
    """
    g.ndata['pv'] = torch.ones(N) / N  # uniform initial distribution
    # Out-degrees are constant across iterations, so compute them once.
    # NOTE(review): a node with out-degree 0 makes pv/degrees inf — confirm
    # the tutorial graphs have no sink nodes.
    degrees = g.out_degrees(g.nodes()).type(torch.float32)
    for k in range(K):
        # Each node distributes its mass evenly over its out-edges...
        g.ndata['pv'] = g.ndata['pv'] / degrees
        g.update_all(message_func=fn.copy_src(src='pv', out='m'),
                     reduce_func=fn.sum(msg='m', out='pv'))
        # ...then applies the damped PageRank update.
        g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['pv']
    return g.ndata['pv']
# Run the demo and print the final PageRank vector.
pv = compute_pagerank(g)
print(pv)
......@@ -11,3 +11,4 @@ from .base import ALL
from .batched_graph import *
from .graph import DGLGraph
from .subgraph import DGLSubGraph
from .udf import NodeBatch, EdgeBatch
......@@ -10,7 +10,7 @@ __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object):
"""Base builtin message function class."""
def __call__(self, src, edge):
def __call__(self, edges):
"""Regular computation of this builtin.
This will be used when optimization is not available.
......@@ -38,14 +38,10 @@ class BundledMessageFunction(MessageFunction):
return False
return True
def __call__(self, src, edge):
ret = None
def __call__(self, edges):
ret = dict()
for fn in self.fn_list:
msg = fn(src, edge)
if ret is None:
ret = msg
else:
# ret and msg must be dict
msg = fn(edges)
ret.update(msg)
return ret
......@@ -83,8 +79,9 @@ class SrcMulEdgeMessageFunction(MessageFunction):
return _is_spmv_supported_node_feat(g, self.src_field) \
and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge):
ret = self.mul_op(src[self.src_field], edge[self.edge_field])
def __call__(self, edges):
ret = self.mul_op(edges.src[self.src_field],
edges.data[self.edge_field])
return {self.out_field : ret}
def name(self):
......@@ -98,8 +95,8 @@ class CopySrcMessageFunction(MessageFunction):
def is_spmv_supported(self, g):
return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge):
return {self.out_field : src[self.src_field]}
def __call__(self, edges):
return {self.out_field : edges.src[self.src_field]}
def name(self):
return "copy_src"
......@@ -114,15 +111,8 @@ class CopyEdgeMessageFunction(MessageFunction):
return False
# return _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge):
if self.edge_field is not None:
ret = edge[self.edge_field]
else:
ret = edge
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def __call__(self, edges):
return {self.out_field : edges.data[self.edge_field]}
def name(self):
return "copy_edge"
......
......@@ -8,7 +8,7 @@ __all__ = ["sum", "max"]
class ReduceFunction(object):
"""Base builtin reduce function class."""
def __call__(self, node, msgs):
def __call__(self, nodes):
"""Regular computation of this builtin.
This will be used when optimization is not available.
......@@ -35,14 +35,10 @@ class BundledReduceFunction(ReduceFunction):
return False
return True
def __call__(self, node, msgs):
ret = None
def __call__(self, nodes):
ret = dict()
for fn in self.fn_list:
rpr = fn(node, msgs)
if ret is None:
ret = rpr
else:
# ret and rpr must be dict
rpr = fn(nodes)
ret.update(rpr)
return ret
......@@ -60,8 +56,8 @@ class ReducerFunctionTemplate(ReduceFunction):
# NOTE: only sum is supported right now.
return self.name == "sum"
def __call__(self, node, msgs):
return {self.out_field : self.op(msgs[self.msg_field], 1)}
def __call__(self, nodes):
return {self.out_field : self.op(nodes.mailbox[self.msg_field], 1)}
def name(self):
return self.name
......
......@@ -9,14 +9,16 @@ import dgl
from .base import ALL, is_all, DGLError, dgl_warning
from . import backend as F
from .backend import Tensor
from .frame import FrameRef, merge_frames
from .frame import FrameRef, Frame, merge_frames
from .function.message import BundledMessageFunction
from .function.reducer import BundledReduceFunction
from .graph_index import GraphIndex, create_graph_index
from . import scheduler
from .udf import NodeBatch, EdgeBatch
from . import utils
from .view import NodeView, EdgeView
__all__ = ['DLGraph']
__all__ = ['DGLGraph']
class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs.
......@@ -58,7 +60,6 @@ class DGLGraph(object):
self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
......@@ -340,67 +341,94 @@ class DGLGraph(object):
src, dst, _ = self._graph.find_edges(eid)
return src.tousertensor(), dst.tousertensor()
def in_edges(self, v):
def in_edges(self, v, form='uv'):
"""Return the in edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
form : str, optional
The return form. Currently support:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
A tuple of Tensors (u, v, eid) if form == 'all'
A pair of Tensors (u, v) if form == 'uv'
One Tensor if form == 'eid'
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def out_edges(self, v):
def out_edges(self, v, form='uv'):
"""Return the out edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
form : str, optional
The return form. Currently support:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
A tuple of Tensors (u, v, eid) if form == 'all'
A pair of Tensors (u, v) if form == 'uv'
One Tensor if form == 'eid'
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def edges(self, sorted=False):
def all_edges(self, form='uv', sorted=False):
"""Return all the edges.
Parameters
----------
form : str, optional
The return form. Currently support:
- 'all' : a tuple (u, v, eid)
- 'uv' : a pair (u, v), default
- 'eid' : one eid tensor
sorted : bool
True if the returned edges are sorted by their src and dst ids.
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
A tuple of Tensors (u, v, eid) if form == 'all'
A pair of Tensors (u, v) if form == 'uv'
One Tensor if form == 'eid'
"""
src, dst, eid = self._graph.edges(sorted)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
if form == 'all':
return (src.tousertensor(), dst.tousertensor(), eid.tousertensor())
elif form == 'uv':
return (src.tousertensor(), dst.tousertensor())
elif form == 'eid':
return eid.tousertensor()
else:
raise DGLError('Invalid form:', form)
def in_degree(self, v):
"""Return the in degree of the node.
......@@ -430,6 +458,7 @@ class DGLGraph(object):
tensor
The in degree array.
"""
v = utils.toindex(v)
return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v):
......@@ -460,6 +489,7 @@ class DGLGraph(object):
tensor
The out degree array.
"""
v = utils.toindex(v)
return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None):
......@@ -581,6 +611,26 @@ class DGLGraph(object):
"""
self._edge_frame.set_initializer(initializer)
@property
def nodes(self):
    """Return a node view that can be used to set/get feature data."""
    return NodeView(self)

@property
def ndata(self):
    """Return the data view of all the nodes."""
    return self.nodes[:].data

@property
def edges(self):
    """Return an edge view that can be used to set/get feature data."""
    return EdgeView(self)

@property
def edata(self):
    """Return the data view of all the edges."""
    return self.edges[:].data
def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation.
......@@ -660,7 +710,7 @@ class DGLGraph(object):
"""
return self._node_frame.pop(key)
def set_e_repr(self, he, u=ALL, v=ALL, inplace=False):
def set_e_repr(self, he, edges=ALL, inplace=False):
"""Set edge(s) representation.
`he` is a dictionary from the feature name to feature tensor. Each tensor
......@@ -674,51 +724,29 @@ class DGLGraph(object):
----------
he : tensor or dict of tensor
Edge representation.
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
inplace : bool
True if the update is done inplacely
"""
# sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if u_is_all:
self.set_e_repr_by_id(he, eid=ALL, inplace=inplace)
else:
# parse argument
if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(he, eid=eid, inplace=inplace)
def set_e_repr_by_id(self, he, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace
flag is true.
else:
eid = utils.toindex(edges)
Parameters
----------
he : tensor or dict of tensor
Edge representation.
eid : int, container or tensor
The edge id(s).
inplace : bool
True if the update is done inplacely
"""
# sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
if is_all(eid):
num_edges = self.number_of_edges()
else:
......@@ -738,33 +766,39 @@ class DGLGraph(object):
# update row
self._edge_frame.update_rows(eid, he, inplace=inplace)
def get_e_repr(self, u=ALL, v=ALL):
def get_e_repr(self, edges=ALL):
"""Get node(s) representation.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
Returns
-------
dict
Representation dict
"""
u_is_all = is_all(u)
v_is_all = is_all(v)
assert u_is_all == v_is_all
if len(self.edge_attr_schemes()) == 0:
return dict()
if u_is_all:
return self.get_e_repr_by_id(eid=ALL)
else:
# parse argument
if is_all(edges):
eid = ALL
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
_, _, eid = self._graph.edge_ids(u, v)
return self.get_e_repr_by_id(eid=eid)
else:
eid = utils.toindex(edges)
if is_all(eid):
return dict(self._edge_frame)
else:
eid = utils.toindex(eid)
return self._edge_frame.select_rows(eid)
def pop_e_repr(self, key):
"""Get and remove the specified edge repr.
......@@ -781,27 +815,6 @@ class DGLGraph(object):
"""
return self._edge_frame.pop(key)
def get_e_repr_by_id(self, eid=ALL):
"""Get edge(s) representation by edge id.
Parameters
----------
eid : int, container or tensor
The edge id(s).
Returns
-------
dict
Representation dict from feature name to feature tensor.
"""
if len(self.edge_attr_schemes()) == 0:
return dict()
if is_all(eid):
return dict(self._edge_frame)
else:
eid = utils.toindex(eid)
return self._edge_frame.select_rows(eid)
def register_edge_func(self, edge_func):
"""Register global edge update function.
......@@ -842,16 +855,6 @@ class DGLGraph(object):
"""
self._apply_node_func = apply_node_func
def register_apply_edge_func(self, apply_edge_func):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
"""
self._apply_edge_func = apply_edge_func
def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations.
......@@ -887,69 +890,24 @@ class DGLGraph(object):
if reduce_accum is not None:
# merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr)
new_repr = apply_node_func(curr_repr)
nb = NodeBatch(self, v, curr_repr)
new_repr = apply_node_func(nb)
if reduce_accum is not None:
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v)
def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None):
"""Apply the function on edge representations.
Applying a None function will be ignored.
Parameters
----------
u : optional, int, iterable of int, tensor
The src node id(s).
v : optional, int, iterable of int, tensor
The dst node id(s).
apply_edge_func : callable
The apply edge function.
eid : None, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
"""
if apply_edge_func == "default":
apply_edge_func = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if eid is None:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
new_repr = apply_edge_func(self.get_e_repr_by_id(eid))
self.set_e_repr_by_id(new_repr, eid)
def send(self, u=None, v=None, message_func="default", eid=None):
"""Trigger the message function on edge u->v or eid
The message function should be compatible with following signature:
(node_reprs, edge_reprs) -> message
It computes the representation of a message using the
representations of the source node, and the edge u->v.
All node_reprs and edge_reprs are dictionaries.
The message function can be any of the pre-defined functions
('from_src').
Currently, we require the message functions of consecutive send's to
return the same keys. Otherwise the behavior will be undefined.
TODO(minjie): document on multiple send behavior
def send(self, edges=ALL, message_func="default"):
"""Send messages along the given edges.
Parameters
----------
u : optional, node, container or tensor
The source node(s).
v : optional, node, container or tensor
The destination node(s).
edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
message_func : callable
The message function.
eid : optional, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
Notes
-----
......@@ -961,131 +919,68 @@ class DGLGraph(object):
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
self._batch_send(u, v, eid, message_func)
def _batch_send(self, u, v, eid, message_func):
if is_all(u) and is_all(v) and eid is None:
u, v, eid = self._graph.edges()
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs)
elif eid is not None:
eid = utils.toindex(eid)
u, v, _ = self._graph.find_edges(eid)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
else:
if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges()
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
# TODO(minjie): Fix these codes in next PR.
"""
new_uv = []
msg_target_rows = []
msg_update_rows = []
msg_append_rows = []
for i, (_u, _v, _eid) in enumerate(zip(u, v, eid)):
if _eid in self._msg_edges:
msg_target_rows.append(self._msg_edges.index(_eid))
msg_update_rows.append(i)
else:
new_uv.append((_u, _v))
self._msg_edges.append(_eid)
msg_append_rows.append(i)
msg_target_rows = utils.toindex(msg_target_rows)
msg_update_rows = utils.toindex(msg_update_rows)
msg_append_rows = utils.toindex(msg_append_rows)
if utils.is_dict_like(msgs):
if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{k: F.gather_row(msgs[k], msg_update_rows.tousertensor())
for k in msgs},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{k: F.gather_row(msgs[k], msg_append_rows.tousertensor())
for k in msgs})
else:
if len(msg_target_rows) > 0:
self._msg_frame.update_rows(
msg_target_rows,
{__MSG__: F.gather_row(msgs, msg_update_rows.tousertensor())},
inplace=False)
if len(msg_append_rows) > 0:
new_u, new_v = zip(*new_uv)
new_u = utils.toindex(new_u)
new_v = utils.toindex(new_v)
self._msg_graph.add_edges(new_u, new_v)
self._msg_frame.append(
{__MSG__: F.gather_row(msgs, msg_append_rows.tousertensor())}
)
"""
def update_edge(self, u=ALL, v=ALL, edge_func="default", eid=None):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
msgs = message_func(eb)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
It computes the new edge representations using the representations
of the source node, target node and the edge itself.
All node_reprs and edge_reprs are dictionaries.
def update_edges(self, edges=ALL, edge_func="default"):
"""Update features on the given edges.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
edge_func : callable
The update function.
eid : optional, edge, container or tensor
The edge to update on. If eid is not None then u and v are ignored.
Notes
-----
On multigraphs, if u and v are specified, then all the edges
between u and v will be updated.
"""
if edge_func == "default":
edge_func = self._edge_func
assert edge_func is not None
self._batch_update_edge(u, v, eid, edge_func)
def _batch_update_edge(self, u, v, eid, edge_func):
if is_all(u) and is_all(v) and eid is None:
u, v, eid = self._graph.edges()
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr()
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr(new_edge_reprs)
else:
if eid is None:
if is_all(edges):
eid = ALL
u, v, _ = self._graph.edges()
elif isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v)
_, _, eid = self._graph.edge_ids(u, v)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr_by_id(eid)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr_by_id(new_edge_reprs, eid)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
self.set_e_repr(edge_func(eb), eid)
def recv(self,
u,
......@@ -1093,25 +988,6 @@ class DGLGraph(object):
apply_node_func="default"):
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
of node u. If no message is found from the predecessors, reduce function
will be skipped.
The reduce function should be compatible with following signature:
(node_reprs, batched_messages) -> node_reprs
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages).
The reduce function can also be pre-defined functions.
An optional apply_node function could be specified and should follow the following
signature:
node_reprs -> node_reprs
All node_reprs and edge_reprs support tensor and dictionary types.
TODO(minjie): document on zero-in-degree case
TODO(minjie): document on how returned new features are merged with the old features
TODO(minjie): document on how many times UDFs will be called
......@@ -1141,11 +1017,13 @@ class DGLGraph(object):
v_is_all = is_all(v)
if v_is_all:
v = list(range(self.number_of_nodes()))
v = F.arange(0, self.number_of_nodes(), dtype=F.int64)
elif isinstance(v, int):
v = [v]
v = utils.toindex(v)
if len(v) == 0:
# no vertex to be triggered.
return
v = utils.toindex(v)
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self._msg_graph, v)
......@@ -1162,7 +1040,7 @@ class DGLGraph(object):
has_zero_degree = True
continue
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
v_data = self.get_n_repr(v_bkt)
uu, vv, in_msg_ids = self._msg_graph.in_edges(v_bkt)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
# Reshape the column tensor to (B, Deg, ...).
......@@ -1173,7 +1051,8 @@ class DGLGraph(object):
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
nb = NodeBatch(self, v_bkt, v_data, reshaped_in_msgs)
new_reprs.append(reduce_func(nb))
# TODO(minjie): clear partial messages
self.reset_messages()
......@@ -1195,26 +1074,26 @@ class DGLGraph(object):
self.set_n_repr(new_reprs, reordered_v)
def send_and_recv(self,
u=None, v=None,
edges,
message_func="default",
reduce_func="default",
apply_node_func="default",
eid=None):
"""Trigger the message function on u->v and update v, or on edge eid
and update the destination nodes.
apply_node_func="default"):
"""Send messages along edges and receive them on the targets.
Parameters
----------
u : optional, node, container or tensor
The source node(s).
v : optional, node, container or tensor
The destination node(s).
message_func : callable
The message function.
reduce_func : callable
The reduce function.
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
message_func : callable, optional
The message function. Registered function will be used if not
specified.
reduce_func : callable, optional
The reduce function. Registered function will be used if not
specified.
apply_node_func : callable, optional
The update function.
The update function. Registered function will be used if not
specified.
Notes
-----
......@@ -1223,69 +1102,58 @@ class DGLGraph(object):
"""
if message_func == "default":
message_func = self._message_func
elif isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if reduce_func == "default":
reduce_func = self._reduce_func
elif isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
assert message_func is not None
assert reduce_func is not None
if eid is None:
if u is None or v is None:
raise ValueError('u and v must be given if eid is None')
if isinstance(edges, tuple):
u, v = edges
u = utils.toindex(u)
v = utils.toindex(v)
# Rewrite u, v to handle edge broadcasting and multigraph.
u, v, eid = self._graph.edge_ids(u, v)
else:
eid = utils.toindex(edges)
u, v, _ = self._graph.find_edges(eid)
if len(u) == 0:
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.tousertensor()))
if not self.is_multigraph:
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
else:
eid = utils.toindex(eid)
if len(eid) == 0:
# no edges to be triggered
return
executor = None
if executor:
new_reprs = executor.run()
accum = executor.run()
unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
elif eid is not None:
_, v, _ = self._graph.find_edges(eid)
unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO(quan): replace with the new DegreeBucketingScheduler
self.send(eid=eid, message_func=message_func)
self.recv(unique_v, reduce_func, apply_node_func)
else:
# handle multiple message and reduce func
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if isinstance(reduce_func, (list, tuple)):
reduce_func = BundledReduceFunction(reduce_func)
# message func
u, v = utils.edge_broadcasting(u, v)
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs)
msg_frame = FrameRef()
msg_frame.append(msgs)
src_data = self.get_n_repr(u)
edge_data = self.get_e_repr(eid)
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
msgs = message_func(eb)
msg_frame = FrameRef(Frame(msgs))
# recv with degree bucketing
executor = scheduler.get_recv_executor(graph=self,
reduce_func=reduce_func,
message_frame=msg_frame,
edges=(u, v))
new_reprs = executor.run()
assert executor is not None
accum = executor.run()
unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
self._apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
def pull(self,
v,
......@@ -1309,7 +1177,7 @@ class DGLGraph(object):
if len(v) == 0:
return
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func, apply_node_func=None)
self.send_and_recv((uu, vv), message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func)
......@@ -1335,7 +1203,7 @@ class DGLGraph(object):
if len(u) == 0:
return
uu, vv, _ = self._graph.out_edges(u)
self.send_and_recv(uu, vv, message_func,
self.send_and_recv((uu, vv), message_func,
reduce_func, apply_node_func)
def update_all(self,
......@@ -1366,7 +1234,7 @@ class DGLGraph(object):
new_reprs = executor.run()
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else:
self.send(ALL, ALL, message_func)
self.send(ALL, message_func)
self.recv(ALL, reduce_func, apply_node_func)
def propagate(self,
......@@ -1401,7 +1269,7 @@ class DGLGraph(object):
else:
# NOTE: the iteration can return multiple edges at each step.
for u, v in traverser:
self.send_and_recv(u, v,
self.send_and_recv((u, v),
message_func, reduce_func, apply_node_func)
def subgraph(self, nodes):
......@@ -1586,18 +1454,19 @@ class DGLGraph(object):
----------
predicate : callable
The predicate should take in a dict of tensors whose values
are concatenation of edge representations by edge ID (same as
get_e_repr_by_id()), and return a boolean tensor with N elements
indicating which node satisfy the predicate.
edges : container or tensor
The edges to filter on
are concatenation of edge representations by edge ID,
and return a boolean tensor with N elements indicating which
node satisfy the predicate.
edges : edges
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
Returns
-------
tensor
The filtered edges
"""
e_repr = self.get_e_repr_by_id(edges)
e_repr = self.get_e_repr(edges)
e_mask = predicate(e_repr)
if is_all(edges):
......
......@@ -5,10 +5,11 @@ import numpy as np
from .base import ALL, DGLError
from . import backend as F
from collections import defaultdict as ddict
from .function import message as fmsg
from .function import reducer as fred
from .udf import NodeBatch, EdgeBatch
from . import utils
from collections import defaultdict as ddict
from ._ffi.function import _init_api
......@@ -176,7 +177,7 @@ class DegreeBucketingExecutor(Executor):
# loop over each bucket
# FIXME (lingfan): handle zero-degree case
for deg, vv, msg_id in zip(self.degrees, self.dsts, self.msg_ids):
dst_reprs = self.g.get_n_repr(vv)
v_data = self.g.get_n_repr(vv)
in_msgs = self.msg_frame.select_rows(msg_id)
def _reshape_fn(msg):
msg_shape = F.shape(msg)
......@@ -184,7 +185,8 @@ class DegreeBucketingExecutor(Executor):
return F.reshape(msg, new_shape)
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))
nb = NodeBatch(self.g, vv, v_data, reshaped_in_msgs)
new_reprs.append(self.rfunc(nb))
# Pack all reducer results together
keys = new_reprs[0].keys()
......@@ -320,7 +322,7 @@ class SendRecvExecutor(BasicExecutor):
@property
def edge_repr(self):
if self._edge_repr is None:
self._edge_repr = self.g.get_e_repr(self.u, self.v)
self._edge_repr = self.g.get_e_repr((self.u, self.v))
return self._edge_repr
def _build_adjmat(self):
......@@ -432,8 +434,11 @@ def _create_send_and_recv_exec(graph, **kwargs):
dst = kwargs.pop('dst')
mfunc = kwargs.pop('message_func')
rfunc = kwargs.pop('reduce_func')
if isinstance(mfunc, (list, tuple)) or isinstance(rfunc, (list, tuple)):
if (isinstance(mfunc, fmsg.BundledMessageFunction)
or isinstance(rfunc, fred.BundledReduceFunction)):
if not isinstance(mfunc, fmsg.BundledMessageFunction):
mfunc = fmsg.BundledMessageFunction(mfunc)
if not isinstance(rfunc, fred.BundledReduceFunction):
rfunc = fred.BundledReduceFunction(rfunc)
exec_cls = BundledSendRecvExecutor
else:
......
"""User-defined function related data structures."""
from __future__ import absolute_import
from collections import Mapping
from .base import ALL, is_all
from . import backend as F
from . import utils
class EdgeBatch(object):
    """The object that represents a batch of edges.

    Parameters
    ----------
    g : DGLGraph
        The graph object.
    edges : tuple of utils.Index
        The edge tuple (u, v, eid). eid can be ALL.
    src_data : dict of tensors
        The src node features.
    edge_data : dict of tensors
        The edge features.
    dst_data : dict of tensors
        The dst node features.
    """
    def __init__(self, g, edges, src_data, edge_data, dst_data):
        self._g = g
        self._edges = edges
        self._src_data = src_data
        self._edge_data = edge_data
        self._dst_data = dst_data

    @property
    def src(self):
        """Return the feature data of the source nodes.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._src_data

    @property
    def dst(self):
        """Return the feature data of the destination nodes.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._dst_data

    @property
    def data(self):
        """Return the edge feature data.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._edge_data

    def edges(self):
        """Return the edges contained in this batch.

        Returns
        -------
        tuple of tensors
            The edge tuple (u, v, eid).
        """
        if is_all(self._edges[2]):
            # Lazily materialize eid when the batch covers all edges.
            # BUGFIX: ``self._edges`` is a tuple (see the class docstring and
            # callers), so assigning ``self._edges[2] = ...`` raised
            # TypeError; rebuild the whole tuple instead.
            eid = utils.toindex(F.arange(
                0, self._g.number_of_edges(), dtype=F.int64))
            self._edges = (self._edges[0], self._edges[1], eid)
        u, v, eid = self._edges
        return (u.tousertensor(), v.tousertensor(), eid.tousertensor())

    def batch_size(self):
        """Return the number of edges in this edge batch."""
        return len(self._edges[0])

    def __len__(self):
        """Return the number of edges in this edge batch."""
        return self.batch_size()
class NodeBatch(object):
    """A batch of nodes together with their features and received messages.

    Parameters
    ----------
    g : DGLGraph
        The graph object.
    nodes : utils.Index or ALL
        The node ids.
    data : dict of tensors
        The node features.
    msgs : dict of tensors, optional
        The messages.
    """
    def __init__(self, g, nodes, data, msgs=None):
        self._g = g
        self._nodes = nodes
        self._data = data
        self._msgs = msgs

    @property
    def data(self):
        """Node feature data.

        Returns
        -------
        dict of str to tensors
            The feature data.
        """
        return self._data

    @property
    def mailbox(self):
        """Received message data, or None if no messages were received.

        Returns
        -------
        dict of str to tensors
            The message data.
        """
        return self._msgs

    def nodes(self):
        """Return the nodes contained in this batch.

        Returns
        -------
        tensor
            The nodes.
        """
        if is_all(self._nodes):
            # Materialize the full node id range on first access and cache it.
            all_ids = F.arange(0, self._g.number_of_nodes(), dtype=F.int64)
            self._nodes = utils.toindex(all_ids)
        return self._nodes.tousertensor()

    def batch_size(self):
        """Return the number of nodes in this node batch."""
        if is_all(self._nodes):
            return self._g.number_of_nodes()
        return len(self._nodes)

    def __len__(self):
        """Return the number of nodes in this node batch."""
        return self.batch_size()
......@@ -20,8 +20,14 @@ class Index(object):
def _dispatch(self, data):
"""Store data based on its type."""
if isinstance(data, Tensor):
if not (F.dtype(data) == F.int64 and len(F.shape(data)) == 1):
if not (F.dtype(data) == F.int64):
raise ValueError('Index data must be an int64 vector, but got: %s' % str(data))
if len(F.shape(data)) > 1:
raise ValueError('Index data must be 1D int64 vector, but got: %s' % str(data))
if len(F.shape(data)) == 0:
# a tensor of one int
self._dispatch(int(data))
else:
self._user_tensor_data[F.get_context(data)] = data
elif isinstance(data, nd.NDArray):
if not (data.dtype == 'int64' and len(data.shape) == 1):
......@@ -343,3 +349,18 @@ def reorder(dict_like, index):
idx_ctx = index.tousertensor(F.get_context(val))
new_dict[key] = F.gather_row(val, idx_ctx)
return new_dict
def parse_edges_tuple(edges):
    """Parse the given edges and return the tuple.

    Parameters
    ----------
    edges : edges
        Edges can be a pair of endpoint nodes (u, v), or a
        tensor of edge ids. The default value is all the edges.

    Returns
    -------
    A tuple of (u, v, eid)
    """
    # NOTE(review): unimplemented stub -- it currently returns None, which
    # contradicts the documented (u, v, eid) return value. Resolving eid from
    # (u, v) presumably needs a graph object this function does not receive;
    # implement or remove before callers start relying on it.
    pass
"""Views of DGLGraph."""
from __future__ import absolute_import
from collections import MutableMapping, namedtuple
from .base import ALL, is_all, DGLError
from . import backend as F
from . import utils
NodeSpace = namedtuple('NodeSpace', ['data'])


class NodeView(object):
    """A NodeView class to act as G.nodes for a DGLGraph.

    Compared with networkx's NodeView, DGL's NodeView is not
    a map. DGL's NodeView supports creating data view from
    a given list of nodes.

    Parameters
    ----------
    graph : DGLGraph
        The graph.
    """
    # BUGFIX: was ``__slot__`` (a misspelled, inert class attribute); only the
    # dunder name ``__slots__`` actually suppresses the per-instance __dict__.
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __len__(self):
        """Return the number of nodes in the graph."""
        return self._graph.number_of_nodes()

    def __getitem__(self, nodes):
        """Return a NodeSpace (data view) over the given nodes.

        Only the full slice ``[:]`` is supported as a slice argument; any
        other selection must be an explicit list/tensor of node ids.
        """
        if isinstance(nodes, slice):
            # slice
            if not (nodes.start is None and nodes.stop is None
                    and nodes.step is None):
                raise DGLError('Currently only full slice ":" is supported')
            return NodeSpace(data=NodeDataView(self._graph, ALL))
        else:
            return NodeSpace(data=NodeDataView(self._graph, nodes))

    def __call__(self):
        """Return the nodes."""
        return F.arange(0, len(self))
class NodeDataView(MutableMapping):
    """A mutable-mapping view over the feature data of selected nodes.

    Reads and writes are delegated to the graph's node frame; ``_nodes`` is
    either ALL or the subset of node ids this view covers.
    """
    # BUGFIX: was ``__slot__`` (a misspelled, inert class attribute); only the
    # dunder name ``__slots__`` actually suppresses the per-instance __dict__.
    __slots__ = ['_graph', '_nodes']

    def __init__(self, graph, nodes):
        self._graph = graph
        self._nodes = nodes

    def __getitem__(self, key):
        return self._graph.get_n_repr(self._nodes)[key]

    def __setitem__(self, key, val):
        self._graph.set_n_repr({key : val}, self._nodes)

    def __delitem__(self, key):
        # Deleting a feature column only makes sense graph-wide; a partial
        # view cannot remove data for just some nodes.
        if not is_all(self._nodes):
            raise DGLError('Delete feature data is not supported on only a subset'
                           ' of nodes. Please use `del G.ndata[key]` instead.')
        self._graph.pop_n_repr(key)

    def __len__(self):
        return len(self._graph._node_frame)

    def __iter__(self):
        return iter(self._graph._node_frame)

    def __repr__(self):
        data = self._graph.get_n_repr(self._nodes)
        return repr({key : data[key] for key in self._graph._node_frame})
EdgeSpace = namedtuple('EdgeSpace', ['data'])


class EdgeView(object):
    """An EdgeView class to act as G.edges for a DGLGraph.

    Supports creating a data view from a given edge selection, and returns
    all edges when called.
    """
    # BUGFIX: was ``__slot__`` (a misspelled, inert class attribute); only the
    # dunder name ``__slots__`` actually suppresses the per-instance __dict__.
    __slots__ = ['_graph']

    def __init__(self, graph):
        self._graph = graph

    def __len__(self):
        """Return the number of edges in the graph."""
        return self._graph.number_of_edges()

    def __getitem__(self, edges):
        """Return an EdgeSpace (data view) over the given edges.

        Only the full slice ``[:]`` is supported as a slice argument; any
        other selection must be an explicit (u, v) pair or edge-id tensor.
        """
        if isinstance(edges, slice):
            # slice
            if not (edges.start is None and edges.stop is None
                    and edges.step is None):
                raise DGLError('Currently only full slice ":" is supported')
            return EdgeSpace(data=EdgeDataView(self._graph, ALL))
        else:
            return EdgeSpace(data=EdgeDataView(self._graph, edges))

    def __call__(self, *args, **kwargs):
        """Return all the edges."""
        return self._graph.all_edges(*args, **kwargs)
class EdgeDataView(MutableMapping):
    """A mutable-mapping view over the feature data of selected edges.

    Reads and writes are delegated to the graph's edge frame; ``_edges`` is
    either ALL or the subset of edges this view covers.
    """
    # BUGFIX: was ``__slot__`` (a misspelled, inert class attribute); only the
    # dunder name ``__slots__`` actually suppresses the per-instance __dict__.
    __slots__ = ['_graph', '_edges']

    def __init__(self, graph, edges):
        self._graph = graph
        self._edges = edges

    def __getitem__(self, key):
        return self._graph.get_e_repr(self._edges)[key]

    def __setitem__(self, key, val):
        self._graph.set_e_repr({key : val}, self._edges)

    def __delitem__(self, key):
        # Deleting a feature column only makes sense graph-wide; a partial
        # view cannot remove data for just some edges.
        # BUGFIX: the message said "subset of nodes" (copy-paste from
        # NodeDataView); this view is over edges.
        if not is_all(self._edges):
            raise DGLError('Delete feature data is not supported on only a subset'
                           ' of edges. Please use `del G.edata[key]` instead.')
        self._graph.pop_e_repr(key)

    def __len__(self):
        return len(self._graph._edge_frame)

    def __iter__(self):
        return iter(self._graph._edge_frame)

    def __repr__(self):
        data = self._graph.get_e_repr(self._edges)
        return repr({key : data[key] for key in self._graph._edge_frame})
......@@ -11,20 +11,20 @@ def check_eq(a, b):
assert a.shape == b.shape
assert mx.nd.sum(a == b).asnumpy() == int(np.prod(list(a.shape)))
def message_func(src, edge):
assert len(src['h'].shape) == 2
assert src['h'].shape[1] == D
return {'m' : src['h']}
def message_func(edges):
assert len(edges.src['h'].shape) == 2
assert edges.src['h'].shape[1] == D
return {'m' : edges.src['h']}
def reduce_func(node, msgs):
msgs = msgs['m']
def reduce_func(nodes):
msgs = nodes.mailbox['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'m' : mx.nd.sum(msgs, 1)}
def apply_node_func(node):
return {'h' : node['h'] + node['m']}
def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['m']}
def generate_graph(grad=False):
g = DGLGraph()
......@@ -38,7 +38,7 @@ def generate_graph(grad=False):
ncol = mx.nd.random.normal(shape=(10, D))
if grad:
ncol.attach_grad()
g.set_n_repr({'h' : ncol})
g.ndata['h'] = ncol
return g
def test_batch_setter_getter():
......@@ -47,15 +47,15 @@ def test_batch_setter_getter():
g = generate_graph()
# set all nodes
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
assert _pfc(g.ndata['h']) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == 0
assert len(g.ndata) == 0
g.set_n_repr({'h' : mx.nd.zeros((10, D))})
# set partial nodes
u = mx.nd.array([1, 3, 5], dtype='int64')
g.set_n_repr({'h' : mx.nd.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = mx.nd.array([1, 2, 3], dtype='int64')
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
......@@ -81,77 +81,77 @@ def test_batch_setter_getter():
9, 0, 16
'''
# set all edges
g.set_e_repr({'l' : mx.nd.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
g.edata['l'] = mx.nd.zeros((17, D))
assert _pfc(g.edata['l']) == [0.] * 17
# pop edges
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr({'l' : mx.nd.zeros((17, D))})
assert len(g.edata) == 0
g.edata['l'] = mx.nd.zeros((17, D))
# set partial edges (many-many)
u = mx.nd.array([0, 0, 2, 5, 9], dtype='int64')
v = mx.nd.array([1, 3, 9, 9, 0], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((5, D))}, u, v)
g.edges[u, v].data['l'] = mx.nd.ones((5, D))
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (many-one)
u = mx.nd.array([3, 4, 6], dtype='int64')
v = mx.nd.array([9], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = mx.nd.ones((3, D))
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (one-many)
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([4, 5, 6], dtype='int64')
g.set_e_repr({'l' : mx.nd.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = mx.nd.ones((3, D))
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# get partial edges (many-many)
u = mx.nd.array([0, 6, 0], dtype='int64')
v = mx.nd.array([6, 9, 7], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = mx.nd.array([5, 6, 7], dtype='int64')
v = mx.nd.array([9], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([3, 4, 5], dtype='int64')
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
with mx.autograd.record():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
h1 = g.ndata['h']
h1.attach_grad()
# partial set
v = mx.nd.array([1, 2, 8], dtype='int64')
hh = mx.nd.zeros((len(v), D))
hh.attach_grad()
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
h2 = g.ndata['h']
h2.backward(mx.nd.ones((10, D)) * 2)
check_eq(h1.grad[:,0], mx.nd.array([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], mx.nd.array([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
def _fmsg(edges):
assert edges.src['h'].shape == (5, D)
return {'m' : edges.src['h']}
g.register_message_func(_fmsg)
# many-many send
u = mx.nd.array([0, 0, 0, 0, 0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v)
g.send((u, v))
# one-many send
u = mx.nd.array([0], dtype='int64')
v = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
g.send(u, v)
g.send((u, v))
# many-one send
u = mx.nd.array([1, 2, 3, 4, 5], dtype='int64')
v = mx.nd.array([9], dtype='int64')
g.send(u, v)
g.send((u, v))
def test_batch_recv():
# basic recv test
......@@ -162,7 +162,7 @@ def test_batch_recv():
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
reduce_msg_shapes.clear()
g.send(u, v)
g.send((u, v))
#g.recv(th.unique(v))
#assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
#reduce_msg_shapes.clear()
......@@ -177,7 +177,7 @@ def test_update_routines():
reduce_msg_shapes.clear()
u = mx.nd.array([0, 0, 0, 4, 5, 6], dtype='int64')
v = mx.nd.array([1, 2, 3, 9, 9, 9], dtype='int64')
g.send_and_recv(u, v)
g.send_and_recv((u, v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
......@@ -208,14 +208,14 @@ def test_reduce_0deg():
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : node['h'] + msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = mx.nd.random.normal(shape=(5, 5))
g.set_n_repr({'h': old_repr})
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[1:].asnumpy(), old_repr[1:].asnumpy())
assert np.allclose(new_repr[0].asnumpy(), old_repr.sum(0).asnumpy())
......@@ -224,25 +224,25 @@ def test_pull_0deg():
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[1].asnumpy())
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
old_repr = mx.nd.random.normal(shape=(2, 5))
g.set_n_repr({'h' : old_repr})
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert np.allclose(new_repr[0].asnumpy(), old_repr[0].asnumpy())
assert np.allclose(new_repr[1].asnumpy(), old_repr[0].asnumpy())
......
......@@ -10,20 +10,20 @@ def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(src, edge):
assert len(src['h'].shape) == 2
assert src['h'].shape[1] == D
return {'m' : src['h']}
def message_func(edges):
assert len(edges.src['h'].shape) == 2
assert edges.src['h'].shape[1] == D
return {'m' : edges.src['h']}
def reduce_func(node, msgs):
msgs = msgs['m']
def reduce_func(nodes):
msgs = nodes.mailbox['m']
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return {'accum' : th.sum(msgs, 1)}
def apply_node_func(node):
return {'h' : node['h'] + node['accum']}
def apply_node_func(nodes):
return {'h' : nodes.data['h'] + nodes.data['accum']}
def generate_graph(grad=False):
g = DGLGraph()
......@@ -36,10 +36,11 @@ def generate_graph(grad=False):
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad)
accumcol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr({'h' : ncol})
g.ndata['h'] = ncol
g.edata['w'] = ecol
g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
return g
def test_batch_setter_getter():
......@@ -47,20 +48,20 @@ def test_batch_setter_getter():
return list(x.numpy()[:,0])
g = generate_graph()
# set all nodes
g.set_n_repr({'h' : th.zeros((10, D))})
assert _pfc(g.get_n_repr()['h']) == [0.] * 10
g.ndata['h'] = th.zeros((10, D))
assert th.allclose(g.ndata['h'], th.zeros((10, D)))
# pop nodes
old_len = len(g.get_n_repr())
old_len = len(g.ndata)
assert _pfc(g.pop_n_repr('h')) == [0.] * 10
assert len(g.get_n_repr()) == old_len - 1
g.set_n_repr({'h' : th.zeros((10, D))})
assert len(g.ndata) == old_len - 1
g.ndata['h'] = th.zeros((10, D))
# set partial nodes
u = th.tensor([1, 3, 5])
g.set_n_repr({'h' : th.ones((3, D))}, u)
assert _pfc(g.get_n_repr()['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
g.nodes[u].data['h'] = th.ones((3, D))
assert _pfc(g.ndata['h']) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = th.tensor([1, 2, 3])
assert _pfc(g.get_n_repr(u)['h']) == [1., 0., 1.]
assert _pfc(g.nodes[u].data['h']) == [1., 0., 1.]
'''
s, d, eid
......@@ -83,75 +84,75 @@ def test_batch_setter_getter():
9, 0, 16
'''
# set all edges
g.set_e_repr({'l' : th.zeros((17, D))})
assert _pfc(g.get_e_repr()['l']) == [0.] * 17
g.edata['l'] = th.zeros((17, D))
assert _pfc(g.edata['l']) == [0.] * 17
# pop edges
old_len = len(g.get_e_repr())
old_len = len(g.edata)
assert _pfc(g.pop_e_repr('l')) == [0.] * 17
assert len(g.get_e_repr()) == old_len - 1
g.set_e_repr({'l' : th.zeros((17, D))})
assert len(g.edata) == old_len - 1
g.edata['l'] = th.zeros((17, D))
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr({'l' : th.ones((5, D))}, u, v)
g.edges[u, v].data['l'] = th.ones((5, D))
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (many-one)
u = th.tensor([3, 4, 6])
v = th.tensor([9])
g.set_e_repr({'l' : th.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = th.ones((3, D))
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# set partial edges (one-many)
u = th.tensor([0])
v = th.tensor([4, 5, 6])
g.set_e_repr({'l' : th.ones((3, D))}, u, v)
g.edges[u, v].data['l'] = th.ones((3, D))
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()['l']) == truth
assert _pfc(g.edata['l']) == truth
# get partial edges (many-many)
u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (many-one)
u = th.tensor([5, 6, 7])
v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 0.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 0.]
# get partial edges (one-many)
u = th.tensor([0])
v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)['l']) == [1., 1., 1.]
assert _pfc(g.edges[u, v].data['l']) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
h1 = g.get_n_repr()['h']
h1 = g.ndata['h']
# partial set
v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr({'h' : hh}, v)
h2 = g.get_n_repr()['h']
g.nodes[v].data['h'] = hh
h2 = g.ndata['h']
h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(src, edge):
assert src['h'].shape == (5, D)
return {'m' : src['h']}
def _fmsg(edges):
assert edges.src['h'].shape == (5, D)
return {'m' : edges.src['h']}
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
g.send((u, v))
# one-many send
u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
g.send((u, v))
# many-one send
u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9])
g.send(u, v)
g.send((u, v))
def test_batch_recv():
# basic recv test
......@@ -162,11 +163,25 @@ def test_batch_recv():
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.send(u, v)
g.send((u, v))
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_update_edges():
def _upd(edges):
return {'w' : edges.data['w'] * 2}
g = generate_graph()
g.register_edge_func(_upd)
old = g.edata['w']
g.update_edges()
assert th.allclose(old * 2, g.edata['w'])
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_edges((u, v), lambda edges : {'w' : edges.data['w'] * 0.})
eid = g.edge_ids(u, v)
assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func)
......@@ -177,7 +192,7 @@ def test_update_routines():
reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.send_and_recv(u, v)
g.send_and_recv((u, v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
......@@ -208,14 +223,14 @@ def test_reduce_0deg():
g.add_edge(2, 0)
g.add_edge(3, 0)
g.add_edge(4, 0)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : node['h'] + msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.data['h'] + nodes.mailbox['m'].sum(1)}
old_repr = th.randn(5, 5)
g.set_n_repr({'h' : old_repr})
g.ndata['h'] = old_repr
g.update_all(_message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0))
......@@ -224,26 +239,26 @@ def test_pull_0deg():
g = DGLGraph()
g.add_nodes(2)
g.add_edge(0, 1)
def _message(src, edge):
return {'m' : src['h']}
def _reduce(node, msgs):
return {'h' : msgs['m'].sum(1)}
def _message(edges):
return {'m' : edges.src['h']}
def _reduce(nodes):
return {'h' : nodes.mailbox['m'].sum(1)}
old_repr = th.randn(2, 5)
g.set_n_repr({'h' : old_repr})
g.ndata['h'] = old_repr
g.pull(0, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5)
g.set_n_repr({'h' : old_repr})
g.ndata['h'] = old_repr
g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr()['h']
new_repr = g.ndata['h']
assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0])
......@@ -253,27 +268,26 @@ def _disabled_test_send_twice():
g.add_nodes(3)
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(src, edge):
return {'a': src['a']}
def _message_b(src, edge):
return {'a': src['a'] * 3}
def _reduce(node, msgs):
assert msgs is not None
return {'a': msgs['a'].max(1)[0]}
def _message_a(edges):
return {'a': edges.src['a']}
def _message_b(edges):
return {'a': edges.src['a'] * 3}
def _reduce(nodes):
return {'a': nodes.mailbox['a'].max(1)[0]}
old_repr = th.randn(3, 5)
g.set_n_repr({'a': old_repr})
g.send(0, 1, _message_a)
g.send(0, 1, _message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((0, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr[0] * 3)
g.set_n_repr({'a': old_repr})
g.send(0, 1, _message_a)
g.send(2, 1, _message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = old_repr
g.send((0, 1), _message_a)
g.send((2, 1), _message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], th.stack([old_repr[0], old_repr[2] * 3], 0).max(0)[0])
def test_send_multigraph():
......@@ -284,64 +298,63 @@ def test_send_multigraph():
g.add_edge(0, 1)
g.add_edge(2, 1)
def _message_a(src, edge):
return {'a': edge['a']}
def _message_b(src, edge):
return {'a': edge['a'] * 3}
def _reduce(node, msgs):
assert msgs is not None
return {'a': msgs['a'].max(1)[0]}
def _message_a(edges):
return {'a': edges.data['a']}
def _message_b(edges):
return {'a': edges.data['a'] * 3}
def _reduce(nodes):
return {'a': nodes.mailbox['a'].max(1)[0]}
def answer(*args):
return th.stack(args, 0).max(0)[0]
# send by eid
old_repr = th.randn(4, 5)
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=[0, 2], message_func=_message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send([0, 2], message_func=_message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2]))
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=[0, 2, 3], message_func=_message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send([0, 2, 3], message_func=_message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
# send on multigraph
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send([0, 2], [1, 1], _message_a)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send(([0, 2], [1, 1]), _message_a)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], old_repr.max(0)[0])
# consecutive send and send_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(2, 1, _message_a)
g.send(eid=[0, 1], message_func=_message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send((2, 1), _message_a)
g.send([0, 1], message_func=_message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0] * 3, old_repr[1] * 3, old_repr[3]))
# consecutive send_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send(eid=0, message_func=_message_a)
g.send(eid=1, message_func=_message_b)
g.recv([1], _reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send(0, message_func=_message_a)
g.send(1, message_func=_message_b)
g.recv(1, _reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[1] * 3))
# send_and_recv_on
g.set_n_repr({'a': th.zeros(3, 5)})
g.set_e_repr({'a': old_repr})
g.send_and_recv(eid=[0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.get_n_repr()['a']
g.ndata['a'] = th.zeros(3, 5)
g.edata['a'] = old_repr
g.send_and_recv([0, 2, 3], message_func=_message_a, reduce_func=_reduce)
new_repr = g.ndata['a']
assert th.allclose(new_repr[1], answer(old_repr[0], old_repr[2], old_repr[3]))
assert th.allclose(new_repr[[0, 2]], th.zeros(2, 5))
......@@ -353,29 +366,25 @@ def test_dynamic_addition():
# Test node addition
g.add_nodes(N)
g.set_n_repr({'h1': th.randn(N, D),
g.ndata.update({'h1': th.randn(N, D),
'h2': th.randn(N, D)})
g.add_nodes(3)
n_repr = g.get_n_repr()
assert n_repr['h1'].shape[0] == n_repr['h2'].shape[0] == N + 3
assert g.ndata['h1'].shape[0] == g.ndata['h2'].shape[0] == N + 3
# Test edge addition
g.add_edge(0, 1)
g.add_edge(1, 0)
g.set_e_repr({'h1': th.randn(2, D),
g.edata.update({'h1': th.randn(2, D),
'h2': th.randn(2, D)})
e_repr = g.get_e_repr()
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 2
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 2
g.add_edges([0, 2], [2, 0])
e_repr = g.get_e_repr()
g.set_e_repr({'h1': th.randn(4, D)})
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 4
g.edata['h1'] = th.randn(4, D)
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 4
g.add_edge(1, 2)
g.set_e_repr_by_id({'h1': th.randn(1, D)}, eid=4)
e_repr = g.get_e_repr()
assert e_repr['h1'].shape[0] == e_repr['h2'].shape[0] == 5
g.edges[4].data['h1'] = th.randn(1, D)
assert g.edata['h1'].shape[0] == g.edata['h2'].shape[0] == 5
if __name__ == '__main__':
......@@ -383,6 +392,7 @@ if __name__ == '__main__':
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_edges()
test_update_routines()
test_reduce_0deg()
test_pull_0deg()
......
......@@ -18,8 +18,8 @@ def tree1():
g.add_edge(4, 1)
g.add_edge(1, 0)
g.add_edge(2, 0)
g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr({'h' : th.randn(4, 10)})
g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
g.edata['h'] = th.randn(4, 10)
return g
def tree2():
......@@ -37,17 +37,17 @@ def tree2():
g.add_edge(0, 4)
g.add_edge(4, 1)
g.add_edge(3, 1)
g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr({'h' : th.randn(4, 10)})
g.ndata['h'] = th.Tensor([0, 1, 2, 3, 4])
g.edata['h'] = th.randn(4, 10)
return g
def test_batch_unbatch():
t1 = tree1()
t2 = tree2()
n1 = t1.get_n_repr()['h']
n2 = t2.get_n_repr()['h']
e1 = t1.get_e_repr()['h']
e2 = t2.get_e_repr()['h']
n1 = t1.ndata['h']
n2 = t2.ndata['h']
e1 = t1.edata['h']
e2 = t2.edata['h']
bg = dgl.batch([t1, t2])
assert bg.number_of_nodes() == 10
......@@ -57,10 +57,10 @@ def test_batch_unbatch():
assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])
assert th.allclose(t1.ndata['h'], tt1.ndata['h'])
assert th.allclose(t1.edata['h'], tt1.edata['h'])
assert th.allclose(t2.ndata['h'], tt2.ndata['h'])
assert th.allclose(t2.edata['h'], tt2.edata['h'])
def test_batch_unbatch1():
t1 = tree1()
......@@ -74,29 +74,29 @@ def test_batch_unbatch1():
assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])
assert th.allclose(t2.ndata['h'], s1.ndata['h'])
assert th.allclose(t2.edata['h'], s1.edata['h'])
assert th.allclose(t1.ndata['h'], s2.ndata['h'])
assert th.allclose(t1.edata['h'], s2.edata['h'])
assert th.allclose(t2.ndata['h'], s3.ndata['h'])
assert th.allclose(t2.edata['h'], s3.edata['h'])
def test_batch_sendrecv():
t1 = tree1()
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
bg.register_message_func(lambda edges: {'m' : edges.src['h']})
bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
u = [3, 4, 2 + 5, 0 + 5]
v = [1, 1, 4 + 5, 4 + 5]
bg.send(u, v)
bg.send((u, v))
bg.recv(v)
t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()['h'][1] == 7
assert t2.get_n_repr()['h'][4] == 2
assert t1.ndata['h'][1] == 7
assert t2.ndata['h'][4] == 2
def test_batch_propagate():
......@@ -104,8 +104,8 @@ def test_batch_propagate():
t2 = tree2()
bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
bg.register_message_func(lambda edges: {'m' : edges.src['h']})
bg.register_reduce_func(lambda nodes: {'h' : th.sum(nodes.mailbox['m'], 1)})
# get leaves.
order = []
......@@ -123,23 +123,23 @@ def test_batch_propagate():
bg.propagate(traverser=order)
t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()['h'][0] == 9
assert t2.get_n_repr()['h'][1] == 5
assert t1.ndata['h'][0] == 9
assert t2.ndata['h'][1] == 5
def test_batched_edge_ordering():
g1 = dgl.DGLGraph()
g1.add_nodes(6)
g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
e1 = th.randn(5, 10)
g1.set_e_repr({'h' : e1})
g1.edata['h'] = e1
g2 = dgl.DGLGraph()
g2.add_nodes(6)
g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
e2 = th.randn(6, 10)
g2.set_e_repr({'h' : e2})
g2.edata['h'] = e2
g = dgl.batch([g1, g2])
r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
r1 = g.edata['h'][g.edge_id(4, 5)]
r2 = g1.edata['h'][g1.edge_id(4, 5)]
assert th.equal(r1, r2)
def test_batch_no_edge():
......
......@@ -12,8 +12,8 @@ def test_filter():
n_repr[[1, 3]] = 1
e_repr[[1, 3]] = 1
g.set_n_repr({'a': n_repr})
g.set_e_repr({'a': e_repr})
g.ndata['a'] = n_repr
g.edata['a'] = e_repr
def predicate(r):
return r['a'].max(1)[0] > 0
......
......@@ -6,7 +6,7 @@ def generate_graph():
g = dgl.DGLGraph()
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11, dtype=th.float)
g.set_n_repr({'h': h})
g.ndata['h'] = h
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
......@@ -15,29 +15,11 @@ def generate_graph():
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr({'h' : h})
g.edata['h'] = h
return g
def generate_graph1():
"""graph with anonymous repr"""
g = dgl.DGLGraph()
g.add_nodes(10) # 10 nodes.
h = th.arange(1, 11, dtype=th.float)
h = th.arange(1, 11, dtype=th.float)
g.set_n_repr(h)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
h = th.tensor([1., 2., 1., 3., 1., 4., 1., 5., 1., 6.,\
1., 7., 1., 8., 1., 9., 10.])
g.set_e_repr(h)
return g
def reducer_both(node, msgs):
return {'h' : th.sum(msgs['m'], 1)}
def reducer_both(nodes):
return {'h' : th.sum(nodes.mailbox['m'], 1)}
def test_copy_src():
# copy_src with both fields
......@@ -45,7 +27,7 @@ def test_copy_src():
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
assert th.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge():
......@@ -54,7 +36,7 @@ def test_copy_edge():
g.register_message_func(fn.copy_edge(edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
assert th.allclose(g.ndata['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge():
......@@ -63,7 +45,7 @@ def test_src_mul_edge():
g.register_message_func(fn.src_mul_edge(src='h', edge='h', out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
assert th.allclose(g.ndata['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment