Commit 3e76bcc0 authored by Minjie Wang's avatar Minjie Wang
Browse files

remove anonymous repr

parent fb6be9fb
...@@ -56,7 +56,7 @@ class GCN(nn.Module): ...@@ -56,7 +56,7 @@ class GCN(nn.Module):
g.apply_nodes(apply_node_func= g.apply_nodes(apply_node_func=
lambda node: F.dropout(node['h'], p=self.dropout)) lambda node: F.dropout(node['h'], p=self.dropout))
self.g.update_all(fn.copy_src(src='h', out='m'), self.g.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msgs='m', out='h'), fn.sum(msg='m', out='h'),
layer) layer)
return self.g.pop_n_repr('h') return self.g.pop_n_repr('h')
......
...@@ -11,5 +11,5 @@ from ._ffi.base import DGLError, __version__ ...@@ -11,5 +11,5 @@ from ._ffi.base import DGLError, __version__
from .base import ALL from .base import ALL
from .batched_graph import * from .batched_graph import *
from .generator import * from .generator import *
from .graph import DGLGraph, __MSG__, __REPR__ from .graph import DGLGraph
from .subgraph import DGLSubGraph from .subgraph import DGLSubGraph
...@@ -11,7 +11,4 @@ ALL = "__ALL__" ...@@ -11,7 +11,4 @@ ALL = "__ALL__"
def is_all(arg): def is_all(arg):
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
dgl_warning = warnings.warn dgl_warning = warnings.warn
...@@ -4,17 +4,25 @@ from __future__ import absolute_import ...@@ -4,17 +4,25 @@ from __future__ import absolute_import
import operator import operator
import dgl.backend as F import dgl.backend as F
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"] __all__ = ["src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object): class MessageFunction(object):
"""Base builtin message function class."""
def __call__(self, src, edge): def __call__(self, src, edge):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise NotImplementedError raise NotImplementedError
def name(self): def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError raise NotImplementedError
...@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction): ...@@ -22,12 +30,6 @@ class BundledMessageFunction(MessageFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)): if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list] fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
# cannot perform check for udf
if isinstance(fn, MessageFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple message is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self, g): def is_spmv_supported(self, g):
...@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction): ...@@ -43,11 +45,8 @@ class BundledMessageFunction(MessageFunction):
if ret is None: if ret is None:
ret = msg ret = msg
else: else:
try:
# ret and msg must be dict # ret and msg must be dict
ret.update(msg) ret.update(msg)
except:
raise RuntimeError("Must specify out field for multiple message")
return ret return ret
def name(self): def name(self):
...@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction): ...@@ -55,25 +54,26 @@ class BundledMessageFunction(MessageFunction):
def _is_spmv_supported_node_feat(g, field): def _is_spmv_supported_node_feat(g, field):
if field is None: """Return whether the node feature shape supports SPMV optimization.
feat = g.get_n_repr()
else: Only scalar and vector features are supported currently.
"""
feat = g.get_n_repr()[field] feat = g.get_n_repr()[field]
shape = F.shape(feat) shape = F.shape(feat)
return len(shape) == 1 or len(shape) == 2 return len(shape) == 1 or len(shape) == 2
def _is_spmv_supported_edge_feat(g, field): def _is_spmv_supported_edge_feat(g, field):
# check shape, only scalar edge feature can be optimized at the moment """Return whether the edge feature shape supports SPMV optimization.
if field is None:
feat = g.get_e_repr() Only scalar feature is supported currently.
else: """
feat = g.get_e_repr()[field] feat = g.get_e_repr()[field]
shape = F.shape(feat) shape = F.shape(feat)
return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1) return len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
class SrcMulEdgeMessageFunction(MessageFunction): class SrcMulEdgeMessageFunction(MessageFunction):
def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None): def __init__(self, mul_op, src_field, edge_field, out_field):
self.mul_op = mul_op self.mul_op = mul_op
self.src_field = src_field self.src_field = src_field
self.edge_field = edge_field self.edge_field = edge_field
...@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction): ...@@ -84,21 +84,14 @@ class SrcMulEdgeMessageFunction(MessageFunction):
and _is_spmv_supported_edge_feat(g, self.edge_field) and _is_spmv_supported_edge_feat(g, self.edge_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: ret = self.mul_op(src[self.src_field], edge[self.edge_field])
src = src[self.src_field]
if self.edge_field is not None:
edge = edge[self.edge_field]
ret = self.mul_op(src, edge)
if self.out_field is None:
return ret
else:
return {self.out_field : ret} return {self.out_field : ret}
def name(self): def name(self):
return "src_mul_edge" return "src_mul_edge"
class CopySrcMessageFunction(MessageFunction): class CopySrcMessageFunction(MessageFunction):
def __init__(self, src_field=None, out_field=None): def __init__(self, src_field, out_field):
self.src_field = src_field self.src_field = src_field
self.out_field = out_field self.out_field = out_field
...@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction): ...@@ -106,14 +99,7 @@ class CopySrcMessageFunction(MessageFunction):
return _is_spmv_supported_node_feat(g, self.src_field) return _is_spmv_supported_node_feat(g, self.src_field)
def __call__(self, src, edge): def __call__(self, src, edge):
if self.src_field is not None: return {self.out_field : src[self.src_field]}
ret = src[self.src_field]
else:
ret = src
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return "copy_src" return "copy_src"
...@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction): ...@@ -142,14 +128,41 @@ class CopyEdgeMessageFunction(MessageFunction):
return "copy_edge" return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None): def src_mul_edge(src, edge, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message by multiplying source node features
with edge features.
Parameters
----------
src : str
The source feature name.
edge : str
The edge feature name.
out : str
The output message name.
"""
return SrcMulEdgeMessageFunction(operator.mul, src, edge, out) return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
def copy_src(src=None, out=None): def copy_src(src, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message using source node feature.
Parameters
----------
src : str
The source feature name.
out : str
The output message name.
"""
return CopySrcMessageFunction(src, out) return CopySrcMessageFunction(src, out)
def copy_edge(edge=None, out=None): def copy_edge(edge, out):
"""TODO(minjie): docstring """ """Builtin message function that computes message using edge feature.
Parameters
----------
edge : str
The edge feature name.
out : str
The output message name.
"""
return CopyEdgeMessageFunction(edge, out) return CopyEdgeMessageFunction(edge, out)
...@@ -3,27 +3,30 @@ from __future__ import absolute_import ...@@ -3,27 +3,30 @@ from __future__ import absolute_import
from .. import backend as F from .. import backend as F
__all__ = ["ReduceFunction", "sum", "max"] __all__ = ["sum", "max"]
class ReduceFunction(object): class ReduceFunction(object):
"""Base builtin reduce function class."""
def __call__(self, node, msgs): def __call__(self, node, msgs):
"""Regular computation of this builtin.
This will be used when optimization is not available.
"""
raise NotImplementedError raise NotImplementedError
def name(self): def name(self):
"""Return the name of this builtin function."""
raise NotImplementedError raise NotImplementedError
def is_spmv_supported(self): def is_spmv_supported(self):
"""Return whether the SPMV optimization is supported."""
raise NotImplementedError raise NotImplementedError
class BundledReduceFunction(ReduceFunction): class BundledReduceFunction(ReduceFunction):
def __init__(self, fn_list): def __init__(self, fn_list):
if not isinstance(fn_list, (list, tuple)): if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list] fn_list = [fn_list]
else:
# sanity check on out field
for fn in fn_list:
if isinstance(fn, ReduceFunction) and fn.out_field is None:
raise RuntimeError("Not specifying out field for multiple reduce is ambiguous")
self.fn_list = fn_list self.fn_list = fn_list
def is_spmv_supported(self): def is_spmv_supported(self):
...@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction): ...@@ -39,51 +42,50 @@ class BundledReduceFunction(ReduceFunction):
if ret is None: if ret is None:
ret = rpr ret = rpr
else: else:
try:
# ret and rpr must be dict # ret and rpr must be dict
ret.update(rpr) ret.update(rpr)
except:
raise RuntimeError("Must specify out field for multiple reudce")
return ret return ret
def name(self): def name(self):
return "bundled" return "bundled"
class ReducerFunctionTemplate(ReduceFunction): class ReducerFunctionTemplate(ReduceFunction):
def __init__(self, name, batch_op, nonbatch_op, msg_field=None, out_field=None): def __init__(self, name, op, msg_field, out_field):
self.name = name self.name = name
self.batch_op = batch_op self.op = op
self.nonbatch_op = nonbatch_op
self.msg_field = msg_field self.msg_field = msg_field
self.out_field = out_field self.out_field = out_field
def is_spmv_supported(self): def is_spmv_supported(self):
# TODO: support max # NOTE: only sum is supported right now.
return self.name == "sum" return self.name == "sum"
def __call__(self, node, msgs): def __call__(self, node, msgs):
if isinstance(msgs, list): return {self.out_field : self.op(msgs[self.msg_field], 1)}
if self.msg_field is None:
ret = self.nonbatch_op(msgs)
else:
ret = self.nonbatch_op([msg[self.msg_field] for msg in msgs])
else:
if self.msg_field is None:
ret = self.batch_op(msgs, 1)
else:
ret = self.batch_op(msgs[self.msg_field], 1)
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self): def name(self):
return self.name return self.name
_python_sum = sum def sum(msg, out):
def sum(msgs=None, out=None): """Builtin reduce function that aggregates messages by sum.
return ReducerFunctionTemplate("sum", F.sum, _python_sum, msgs, out)
Parameters
----------
msg : str
The message name.
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("sum", F.sum, msg, out)
def max(msg, out):
"""Builtin reduce function that aggregates messages by max.
_python_max = max Parameters
def max(msgs=None, out=None): ----------
return ReducerFunctionTemplate("max", F.max, _python_max, msgs, out) msg : str
The message name.
out : str
The output node feature name.
"""
return ReducerFunctionTemplate("max", F.max, msg, out)
...@@ -6,7 +6,7 @@ import networkx as nx ...@@ -6,7 +6,7 @@ import networkx as nx
import numpy as np import numpy as np
import dgl import dgl
from .base import ALL, is_all, __MSG__, __REPR__ from .base import ALL, is_all, DGLError, dgl_warning
from . import backend as F from . import backend as F
from .backend import Tensor from .backend import Tensor
from .frame import FrameRef, merge_frames from .frame import FrameRef, merge_frames
...@@ -22,7 +22,6 @@ class DGLGraph(object): ...@@ -22,7 +22,6 @@ class DGLGraph(object):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
TODO(minjie): document of batching semantics TODO(minjie): document of batching semantics
TODO(minjie): document of __REPR__ semantics
Parameters Parameters
---------- ----------
...@@ -448,7 +447,9 @@ class DGLGraph(object): ...@@ -448,7 +447,9 @@ class DGLGraph(object):
The nx graph The nx graph
""" """
nx_graph = self._graph.to_networkx() nx_graph = self._graph.to_networkx()
#TODO: attributes #TODO(minjie): attributes
dgl_warning('to_networkx currently does not support converting'
' node/edge features automatically.')
return nx_graph return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None): def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
...@@ -550,20 +551,17 @@ class DGLGraph(object): ...@@ -550,20 +551,17 @@ class DGLGraph(object):
def set_n_repr(self, hu, u=ALL, inplace=False): def set_n_repr(self, hu, u=ALL, inplace=False):
"""Set node(s) representation. """Set node(s) representation.
To set multiple node representations at once, pass `u` with a tensor or `hu` is a dictionary from the feature name to feature tensor. Each tensor
a supported container of node ids. In this case, `hu` must be a tensor is of shape (B, D1, D2, ...), where B is the number of nodes to be updated,
of shape (B, D1, D2, ...), where B is the number of the nodes and and (D1, D2, ...) be the shape of the node representation tensor. The
(D1, D2, ...) is the shape of the node representation tensor. length of the given node ids must match B (i.e, len(u) == B).
Dictionary type is also supported for `hu`. In this case, each item
will be treated as separate attribute of the nodes.
All update will be done out-placely to work with autograd unless the inplace All update will be done out-placely to work with autograd unless the inplace
flag is true. flag is true.
Parameters Parameters
---------- ----------
hu : tensor or dict of tensor hu : dict of tensor
Node representation. Node representation.
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
...@@ -571,32 +569,31 @@ class DGLGraph(object): ...@@ -571,32 +569,31 @@ class DGLGraph(object):
True if the update is done inplacely True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(hu):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(hu))
if is_all(u): if is_all(u):
num_nodes = self.number_of_nodes() num_nodes = self.number_of_nodes()
else: else:
u = utils.toindex(u) u = utils.toindex(u)
num_nodes = len(u) num_nodes = len(u)
if utils.is_dict_like(hu):
for key, val in hu.items(): for key, val in hu.items():
assert F.shape(val)[0] == num_nodes nfeats = F.shape(val)[0]
else: if nfeats != num_nodes:
assert F.shape(hu)[0] == num_nodes raise DGLError('Expect number of features to match number of nodes (len(u)).'
' Got %d and %d instead.' % (nfeats, num_nodes))
# set # set
if is_all(u): if is_all(u):
if utils.is_dict_like(hu):
for key, val in hu.items(): for key, val in hu.items():
self._node_frame[key] = val self._node_frame[key] = val
else: else:
self._node_frame[__REPR__] = hu
else:
if utils.is_dict_like(hu):
self._node_frame.update_rows(u, hu, inplace=inplace) self._node_frame.update_rows(u, hu, inplace=inplace)
else:
self._node_frame.update_rows(u, {__REPR__ : hu}, inplace=inplace)
def get_n_repr(self, u=ALL): def get_n_repr(self, u=ALL):
"""Get node(s) representation. """Get node(s) representation.
The returned feature tensor batches multiple node features on the first dimension.
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
...@@ -605,23 +602,17 @@ class DGLGraph(object): ...@@ -605,23 +602,17 @@ class DGLGraph(object):
Returns Returns
------- -------
dict dict
Representation dict Representation dict from feature name to feature tensor.
""" """
if len(self.node_attr_schemes()) == 0: if len(self.node_attr_schemes()) == 0:
return dict() return dict()
if is_all(u): if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame[__REPR__]
else:
return dict(self._node_frame) return dict(self._node_frame)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
return self._node_frame.select_rows(u)[__REPR__]
else:
return self._node_frame.select_rows(u) return self._node_frame.select_rows(u)
def pop_n_repr(self, key=__REPR__): def pop_n_repr(self, key):
"""Get and remove the specified node repr. """Get and remove the specified node repr.
Parameters Parameters
...@@ -636,23 +627,19 @@ class DGLGraph(object): ...@@ -636,23 +627,19 @@ class DGLGraph(object):
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
def set_e_repr(self, h_uv, u=ALL, v=ALL, inplace=False): def set_e_repr(self, he, u=ALL, v=ALL, inplace=False):
"""Set edge(s) representation. """Set edge(s) representation.
To set multiple edge representations at once, pass `u` and `v` with tensors or `he` is a dictionary from the feature name to feature tensor. Each tensor
supported containers of node ids. In this case, `h_uv` must be a tensor is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
of shape (B, D1, D2, ...), where B is the number of the edges and and (D1, D2, ...) be the shape of the edge representation tensor.
(D1, D2, ...) is the shape of the edge representation tensor.
Dictionary type is also supported for `h_uv`. In this case, each item
will be treated as separate attribute of the edges.
All update will be done out-placely to work with autograd unless the inplace All update will be done out-placely to work with autograd unless the inplace
flag is true. flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
u : node, container or tensor u : node, container or tensor
The source node(s). The source node(s).
...@@ -662,26 +649,33 @@ class DGLGraph(object): ...@@ -662,26 +649,33 @@ class DGLGraph(object):
True if the update is done inplacely True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
assert u_is_all == v_is_all assert u_is_all == v_is_all
if u_is_all: if u_is_all:
self.set_e_repr_by_id(h_uv, eid=ALL, inplace=inplace) self.set_e_repr_by_id(he, eid=ALL, inplace=inplace)
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
self.set_e_repr_by_id(h_uv, eid=eid, inplace=inplace) self.set_e_repr_by_id(he, eid=eid, inplace=inplace)
def set_e_repr_by_id(self, h_uv, eid=ALL, inplace=False): def set_e_repr_by_id(self, he, eid=ALL, inplace=False):
"""Set edge(s) representation by edge id. """Set edge(s) representation by edge id.
`he` is a dictionary from the feature name to feature tensor. Each tensor
is of shape (B, D1, D2, ...), where B is the number of edges to be updated,
and (D1, D2, ...) be the shape of the edge representation tensor.
All update will be done out-placely to work with autograd unless the inplace All update will be done out-placely to work with autograd unless the inplace
flag is true. flag is true.
Parameters Parameters
---------- ----------
h_uv : tensor or dict of tensor he : tensor or dict of tensor
Edge representation. Edge representation.
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
...@@ -689,30 +683,27 @@ class DGLGraph(object): ...@@ -689,30 +683,27 @@ class DGLGraph(object):
True if the update is done inplacely True if the update is done inplacely
""" """
# sanity check # sanity check
if not utils.is_dict_like(he):
raise DGLError('Expect dictionary type for feature data.'
' Got "%s" instead.' % type(he))
if is_all(eid): if is_all(eid):
num_edges = self.number_of_edges() num_edges = self.number_of_edges()
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
num_edges = len(eid) num_edges = len(eid)
if utils.is_dict_like(h_uv): for key, val in he.items():
for key, val in h_uv.items(): nfeats = F.shape(val)[0]
assert F.shape(val)[0] == num_edges if nfeats != num_edges:
else: raise DGLError('Expect number of features to match number of edges.'
assert F.shape(h_uv)[0] == num_edges ' Got %d and %d instead.' % (nfeats, num_edges))
# set # set
if is_all(eid): if is_all(eid):
# update column # update column
if utils.is_dict_like(h_uv): for key, val in he.items():
for key, val in h_uv.items():
self._edge_frame[key] = val self._edge_frame[key] = val
else:
self._edge_frame[__REPR__] = h_uv
else: else:
# update row # update row
if utils.is_dict_like(h_uv): self._edge_frame.update_rows(eid, he, inplace=inplace)
self._edge_frame.update_rows(eid, h_uv, inplace=inplace)
else:
self._edge_frame.update_rows(eid, {__REPR__ : h_uv}, inplace=inplace)
def get_e_repr(self, u=ALL, v=ALL): def get_e_repr(self, u=ALL, v=ALL):
"""Get node(s) representation. """Get node(s) representation.
...@@ -742,7 +733,7 @@ class DGLGraph(object): ...@@ -742,7 +733,7 @@ class DGLGraph(object):
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
return self.get_e_repr_by_id(eid=eid) return self.get_e_repr_by_id(eid=eid)
def pop_e_repr(self, key=__REPR__): def pop_e_repr(self, key):
"""Get and remove the specified edge repr. """Get and remove the specified edge repr.
Parameters Parameters
...@@ -768,20 +759,14 @@ class DGLGraph(object): ...@@ -768,20 +759,14 @@ class DGLGraph(object):
Returns Returns
------- -------
dict dict
Representation dict Representation dict from feature name to feature tensor.
""" """
if len(self.edge_attr_schemes()) == 0: if len(self.edge_attr_schemes()) == 0:
return dict() return dict()
if is_all(eid): if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame[__REPR__]
else:
return dict(self._edge_frame) return dict(self._edge_frame)
else: else:
eid = utils.toindex(eid) eid = utils.toindex(eid)
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
return self._edge_frame.select_rows(eid)[__REPR__]
else:
return self._edge_frame.select_rows(eid) return self._edge_frame.select_rows(eid)
def register_edge_func(self, edge_func): def register_edge_func(self, edge_func):
...@@ -837,6 +822,8 @@ class DGLGraph(object): ...@@ -837,6 +822,8 @@ class DGLGraph(object):
def apply_nodes(self, v=ALL, apply_node_func="default"): def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations. """Apply the function on node representations.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
v : int, iterable of int, tensor, optional v : int, iterable of int, tensor, optional
...@@ -868,7 +855,7 @@ class DGLGraph(object): ...@@ -868,7 +855,7 @@ class DGLGraph(object):
# merge current node_repr with reduce output # merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr) curr_repr = utils.HybridDict(reduce_accum, curr_repr)
new_repr = apply_node_func(curr_repr) new_repr = apply_node_func(curr_repr)
if reduce_accum is not None and utils.is_dict_like(new_repr) : if reduce_accum is not None:
# merge new node_repr with reduce output # merge new node_repr with reduce output
reduce_accum.update(new_repr) reduce_accum.update(new_repr)
new_repr = reduce_accum new_repr = reduce_accum
...@@ -877,6 +864,8 @@ class DGLGraph(object): ...@@ -877,6 +864,8 @@ class DGLGraph(object):
def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None): def apply_edges(self, u=None, v=None, apply_edge_func="default", eid=None):
"""Apply the function on edge representations. """Apply the function on edge representations.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
u : optional, int, iterable of int, tensor u : optional, int, iterable of int, tensor
...@@ -893,7 +882,6 @@ class DGLGraph(object): ...@@ -893,7 +882,6 @@ class DGLGraph(object):
if not apply_edge_func: if not apply_edge_func:
# Skip none function call. # Skip none function call.
return return
if eid is None: if eid is None:
new_repr = apply_edge_func(self.get_e_repr(u, v)) new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v) self.set_e_repr(new_repr, u, v)
...@@ -914,9 +902,8 @@ class DGLGraph(object): ...@@ -914,9 +902,8 @@ class DGLGraph(object):
The message function can be any of the pre-defined functions The message function can be any of the pre-defined functions
('from_src'). ('from_src').
Currently, we require the message functions of consecutive send's and Currently, we require the message functions of consecutive send's to
send_on's to return the same keys. Otherwise the behavior will be return the same keys. Otherwise the behavior will be undefined.
undefined.
Parameters Parameters
---------- ----------
...@@ -964,10 +951,7 @@ class DGLGraph(object): ...@@ -964,10 +951,7 @@ class DGLGraph(object):
edge_reprs = self.get_e_repr_by_id(eid) edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
self._msg_graph.add_edges(u, v) self._msg_graph.add_edges(u, v)
if utils.is_dict_like(msgs):
self._msg_frame.append(msgs) self._msg_frame.append(msgs)
else:
self._msg_frame.append({__MSG__ : msgs})
# TODO(minjie): Fix these codes in next PR. # TODO(minjie): Fix these codes in next PR.
""" """
...@@ -1061,7 +1045,6 @@ class DGLGraph(object): ...@@ -1061,7 +1045,6 @@ class DGLGraph(object):
v = utils.toindex(v) v = utils.toindex(v)
u, v = utils.edge_broadcasting(u, v) u, v = utils.edge_broadcasting(u, v)
_, _, eid = self._graph.edge_ids(u, v) _, _, eid = self._graph.edge_ids(u, v)
# call the UDF # call the UDF
src_reprs = self.get_n_repr(u) src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v) dst_reprs = self.get_n_repr(v)
...@@ -1148,25 +1131,19 @@ class DGLGraph(object): ...@@ -1148,25 +1131,19 @@ class DGLGraph(object):
msg_shape = F.shape(msg) msg_shape = F.shape(msg)
new_shape = (bkt_len, deg) + msg_shape[1:] new_shape = (bkt_len, deg) + msg_shape[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
if len(in_msgs) == 1 and __MSG__ in in_msgs:
reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
else:
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
reordered_v.append(v_bkt.tousertensor()) reordered_v.append(v_bkt.tousertensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs)) new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages # TODO(minjie): clear partial messages
self.reset_messages() self.reset_messages()
# Pack all reducer results together # Pack all reducer results together
reordered_v = F.pack(reordered_v) reordered_v = F.pack(reordered_v)
if utils.is_dict_like(new_reprs[0]):
keys = new_reprs[0].keys() keys = new_reprs[0].keys()
new_reprs = {key : F.pack([repr[key] for repr in new_reprs]) new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
for key in keys} for key in keys}
else:
new_reprs = {__REPR__ : F.pack(new_reprs)}
if v_is_all and not has_zero_degree: if v_is_all and not has_zero_degree:
# First do reorder and then replace the whole column. # First do reorder and then replace the whole column.
...@@ -1237,15 +1214,13 @@ class DGLGraph(object): ...@@ -1237,15 +1214,13 @@ class DGLGraph(object):
if executor: if executor:
new_reprs = executor.run() new_reprs = executor.run()
if not utils.is_dict_like(new_reprs):
new_reprs = {__REPR__: new_reprs}
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(unique_v, apply_node_func, reduce_accum=new_reprs)
elif eid is not None: elif eid is not None:
_, v, _ = self._graph.find_edges(eid) _, v, _ = self._graph.find_edges(eid)
unique_v = utils.toindex(F.unique(v.tousertensor())) unique_v = utils.toindex(F.unique(v.tousertensor()))
# TODO: replace with the new DegreeBucketingScheduler # TODO(quan): replace with the new DegreeBucketingScheduler
self.send(eid=eid, message_func=message_func) self.send(eid=eid, message_func=message_func)
self.recv(unique_v, reduce_func, apply_node_func) self.recv(unique_v, reduce_func, apply_node_func)
else: else:
...@@ -1261,10 +1236,7 @@ class DGLGraph(object): ...@@ -1261,10 +1236,7 @@ class DGLGraph(object):
edge_reprs = self.get_e_repr(u, v) edge_reprs = self.get_e_repr(u, v)
msgs = message_func(src_reprs, edge_reprs) msgs = message_func(src_reprs, edge_reprs)
msg_frame = FrameRef() msg_frame = FrameRef()
if utils.is_dict_like(msgs):
msg_frame.append(msgs) msg_frame.append(msgs)
else:
msg_frame.append({__MSG__: msgs})
# recv with degree bucketing # recv with degree bucketing
executor = scheduler.get_recv_executor(graph=self, executor = scheduler.get_recv_executor(graph=self,
...@@ -1353,8 +1325,6 @@ class DGLGraph(object): ...@@ -1353,8 +1325,6 @@ class DGLGraph(object):
"update_all", self, message_func=message_func, reduce_func=reduce_func) "update_all", self, message_func=message_func, reduce_func=reduce_func)
if executor: if executor:
new_reprs = executor.run() new_reprs = executor.run()
if not utils.is_dict_like(new_reprs):
new_reprs = {__REPR__: new_reprs}
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs) self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else: else:
self.send(ALL, ALL, message_func) self.send(ALL, ALL, message_func)
...@@ -1387,7 +1357,7 @@ class DGLGraph(object): ...@@ -1387,7 +1357,7 @@ class DGLGraph(object):
Arguments for pre-defined iterators. Arguments for pre-defined iterators.
""" """
if isinstance(traverser, str): if isinstance(traverser, str):
# TODO Call pre-defined routine to unroll the computation. # TODO(minjie): Call pre-defined routine to unroll the computation.
raise RuntimeError('Not implemented.') raise RuntimeError('Not implemented.')
else: else:
# NOTE: the iteration can return multiple edges at each step. # NOTE: the iteration can return multiple edges at each step.
......
...@@ -3,7 +3,7 @@ from __future__ import absolute_import ...@@ -3,7 +3,7 @@ from __future__ import absolute_import
import numpy as np import numpy as np
from .base import ALL, __MSG__, __REPR__ from .base import ALL, DGLError
from . import backend as F from . import backend as F
from .function import message as fmsg from .function import message as fmsg
from .function import reducer as fred from .function import reducer as fred
...@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph): ...@@ -111,7 +111,15 @@ def light_degree_bucketing_for_graph(graph):
class Executor(object): class Executor(object):
"""Base class for executing graph computation."""
def run(self): def run(self):
"""Run this executor.
This should return the new node features.
TODO(minjie): extend this to support computation on edges.
"""
raise NotImplementedError raise NotImplementedError
class SPMVOperator(Executor): class SPMVOperator(Executor):
...@@ -126,9 +134,6 @@ class SPMVOperator(Executor): ...@@ -126,9 +134,6 @@ class SPMVOperator(Executor):
def run(self): def run(self):
# get src col # get src col
if self.src_field is None:
srccol = self.node_repr
else:
srccol = self.node_repr[self.src_field] srccol = self.node_repr[self.src_field]
ctx = F.get_context(srccol) ctx = F.get_context(srccol)
...@@ -142,9 +147,6 @@ class SPMVOperator(Executor): ...@@ -142,9 +147,6 @@ class SPMVOperator(Executor):
dstcol = F.squeeze(dstcol) dstcol = F.squeeze(dstcol)
else: else:
dstcol = F.spmm(adjmat, srccol) dstcol = F.spmm(adjmat, srccol)
if self.dst_field is None:
return dstcol
else:
return {self.dst_field : dstcol} return {self.dst_field : dstcol}
...@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor): ...@@ -180,20 +182,14 @@ class DegreeBucketingExecutor(Executor):
msg_shape = F.shape(msg) msg_shape = F.shape(msg)
new_shape = (len(vv), deg) + msg_shape[1:] new_shape = (len(vv), deg) + msg_shape[1:]
return F.reshape(msg, new_shape) return F.reshape(msg, new_shape)
if len(in_msgs) == 1 and __MSG__ in in_msgs:
reshaped_in_msgs = _reshape_fn(in_msgs[__MSG__])
else:
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self.msg_frame.schemes)
new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs)) new_reprs.append(self.rfunc(dst_reprs, reshaped_in_msgs))
# Pack all reducer results together # Pack all reducer results together
if utils.is_dict_like(new_reprs[0]):
keys = new_reprs[0].keys() keys = new_reprs[0].keys()
new_reprs = {key : F.pack([repr[key] for repr in new_reprs]) new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
for key in keys} for key in keys}
else:
new_reprs = {__REPR__ : F.pack(new_reprs)}
return new_reprs return new_reprs
...@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -249,12 +245,6 @@ class UpdateAllExecutor(BasicExecutor):
self._graph_shape = None self._graph_shape = None
self._recv_nodes = None self._recv_nodes = None
@property
def graph_idx(self):
if self._graph_idx is None:
self._graph_idx = self.g._graph.adjacency_matrix()
return self._graph_idx
@property @property
def graph_shape(self): def graph_shape(self):
if self._graph_shape is None: if self._graph_shape is None:
...@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor): ...@@ -280,16 +270,13 @@ class UpdateAllExecutor(BasicExecutor):
def _adj_build_fn(self, edge_field, ctx, use_edge_feat): def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat: if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field] dat = self.edge_repr[edge_field]
dat = F.squeeze(dat) dat = F.squeeze(dat)
# TODO(minjie): should not directly use _indices # TODO(minjie): should not directly use _indices
idx = self.graph_idx.get(ctx)._indices() idx = self.g.adjacency_matrix(ctx)._indices()
adjmat = F.sparse_tensor(idx, dat, self.graph_shape) adjmat = F.sparse_tensor(idx, dat, self.graph_shape)
else: else:
adjmat = self.graph_idx.get(ctx) adjmat = self.g.adjacency_matrix(ctx)
return adjmat return adjmat
...@@ -351,9 +338,6 @@ class SendRecvExecutor(BasicExecutor): ...@@ -351,9 +338,6 @@ class SendRecvExecutor(BasicExecutor):
def _adj_build_fn(self, edge_field, ctx, use_edge_feat): def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
if use_edge_feat: if use_edge_feat:
if edge_field is None:
dat = self.edge_repr
else:
dat = self.edge_repr[edge_field] dat = self.edge_repr[edge_field]
dat = F.squeeze(dat) dat = F.squeeze(dat)
else: else:
...@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor): ...@@ -386,9 +370,8 @@ class BundledExecutor(BasicExecutor):
func_pairs = [] func_pairs = []
for rfn in rfunc.fn_list: for rfn in rfunc.fn_list:
mfn = out2mfunc.get(rfn.msg_field, None) mfn = out2mfunc.get(rfn.msg_field, None)
# field check if mfn is None:
assert mfn is not None, \ raise DGLError('Cannot find message field "%s".' % rfn.msg_field)
"cannot find message func for reduce func in-field {}".format(rfn.msg_field)
func_pairs.append((mfn, rfn)) func_pairs.append((mfn, rfn))
return func_pairs return func_pairs
...@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor): ...@@ -409,7 +392,6 @@ class BundledUpdateAllExecutor(BundledExecutor, UpdateAllExecutor):
self._init_state() self._init_state()
BundledExecutor.__init__(self, graph, mfunc, rfunc) BundledExecutor.__init__(self, graph, mfunc, rfunc)
class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor): class BundledSendRecvExecutor(BundledExecutor, SendRecvExecutor):
def __init__(self, graph, src, dst, mfunc, rfunc): def __init__(self, graph, src, dst, mfunc, rfunc):
self._init_state(src, dst) self._init_state(src, dst)
......
...@@ -209,14 +209,13 @@ def test_reduce_0deg(): ...@@ -209,14 +209,13 @@ def test_reduce_0deg():
g.add_edge(3, 0) g.add_edge(3, 0)
g.add_edge(4, 0) g.add_edge(4, 0)
def _message(src, edge): def _message(src, edge):
return src return {'m' : src['h']}
def _reduce(node, msgs): def _reduce(node, msgs):
assert msgs is not None return {'h' : node['h'] + msgs['m'].sum(1)}
return node + msgs.sum(1)
old_repr = th.randn(5, 5) old_repr = th.randn(5, 5)
g.set_n_repr(old_repr) g.set_n_repr({'h' : old_repr})
g.update_all(_message, _reduce) g.update_all(_message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[1:], old_repr[1:]) assert th.allclose(new_repr[1:], old_repr[1:])
assert th.allclose(new_repr[0], old_repr.sum(0)) assert th.allclose(new_repr[0], old_repr.sum(0))
...@@ -226,25 +225,25 @@ def test_pull_0deg(): ...@@ -226,25 +225,25 @@ def test_pull_0deg():
g.add_nodes(2) g.add_nodes(2)
g.add_edge(0, 1) g.add_edge(0, 1)
def _message(src, edge): def _message(src, edge):
return src return {'m' : src['h']}
def _reduce(node, msgs): def _reduce(node, msgs):
assert msgs is not None return {'h' : msgs['m'].sum(1)}
return msgs.sum(1)
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr(old_repr) g.set_n_repr({'h' : old_repr})
g.pull(0, _message, _reduce) g.pull(0, _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[1]) assert th.allclose(new_repr[1], old_repr[1])
g.pull(1, _message, _reduce) g.pull(1, _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
old_repr = th.randn(2, 5) old_repr = th.randn(2, 5)
g.set_n_repr(old_repr) g.set_n_repr({'h' : old_repr})
g.pull([0, 1], _message, _reduce) g.pull([0, 1], _message, _reduce)
new_repr = g.get_n_repr() new_repr = g.get_n_repr()['h']
assert th.allclose(new_repr[0], old_repr[0]) assert th.allclose(new_repr[0], old_repr[0])
assert th.allclose(new_repr[1], old_repr[0]) assert th.allclose(new_repr[1], old_repr[0])
......
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.graph import DGLGraph, __REPR__
D = 32
reduce_msg_shapes = set()
def check_eq(a, b):
assert a.shape == b.shape
assert th.sum(a == b) == int(np.prod(list(a.shape)))
def message_func(hu, e_uv):
assert len(hu.shape) == 2
assert hu.shape[1] == D
return hu
def reduce_func(hv, msgs):
reduce_msg_shapes.add(tuple(msgs.shape))
assert len(msgs.shape) == 3
assert msgs.shape[2] == D
return hv + th.sum(msgs, 1)
def generate_graph(grad=False):
g = DGLGraph()
g.add_nodes(10)
# create a graph where 0 is the source and 9 is the sink
for i in range(1, 9):
g.add_edge(0, i)
g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
ncol = Variable(th.randn(10, D), requires_grad=grad)
ecol = Variable(th.randn(17, D), requires_grad=grad)
g.set_n_repr(ncol)
g.set_e_repr(ecol)
return g
def test_batch_setter_getter():
def _pfc(x):
return list(x.numpy()[:,0])
g = generate_graph()
# set all nodes
g.set_n_repr(th.zeros((10, D)))
assert _pfc(g.get_n_repr()) == [0.] * 10
# pop nodes
assert _pfc(g.pop_n_repr()) == [0.] * 10
assert len(g.get_n_repr()) == 0
g.set_n_repr(th.zeros((10, D)))
# set partial nodes
u = th.tensor([1, 3, 5])
g.set_n_repr(th.ones((3, D)), u)
assert _pfc(g.get_n_repr()) == [0., 1., 0., 1., 0., 1., 0., 0., 0., 0.]
# get partial nodes
u = th.tensor([1, 2, 3])
assert _pfc(g.get_n_repr(u)) == [1., 0., 1.]
'''
s, d, eid
0, 1, 0
1, 9, 1
0, 2, 2
2, 9, 3
0, 3, 4
3, 9, 5
0, 4, 6
4, 9, 7
0, 5, 8
5, 9, 9
0, 6, 10
6, 9, 11
0, 7, 12
7, 9, 13
0, 8, 14
8, 9, 15
9, 0, 16
'''
# set all edges
g.set_e_repr(th.zeros((17, D)))
assert _pfc(g.get_e_repr()) == [0.] * 17
# pop edges
assert _pfc(g.pop_e_repr()) == [0.] * 17
assert len(g.get_e_repr()) == 0
g.set_e_repr(th.zeros((17, D)))
# set partial edges (many-many)
u = th.tensor([0, 0, 2, 5, 9])
v = th.tensor([1, 3, 9, 9, 0])
g.set_e_repr(th.ones((5, D)), u, v)
truth = [0.] * 17
truth[0] = truth[4] = truth[3] = truth[9] = truth[16] = 1.
assert _pfc(g.get_e_repr()) == truth
# set partial edges (many-one)
u = th.tensor([3, 4, 6])
v = th.tensor([9])
g.set_e_repr(th.ones((3, D)), u, v)
truth[5] = truth[7] = truth[11] = 1.
assert _pfc(g.get_e_repr()) == truth
# set partial edges (one-many)
u = th.tensor([0])
v = th.tensor([4, 5, 6])
g.set_e_repr(th.ones((3, D)), u, v)
truth[6] = truth[8] = truth[10] = 1.
assert _pfc(g.get_e_repr()) == truth
# get partial edges (many-many)
u = th.tensor([0, 6, 0])
v = th.tensor([6, 9, 7])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
# get partial edges (many-one)
u = th.tensor([5, 6, 7])
v = th.tensor([9])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 0.]
# get partial edges (one-many)
u = th.tensor([0])
v = th.tensor([3, 4, 5])
assert _pfc(g.get_e_repr(u, v)) == [1., 1., 1.]
def test_batch_setter_autograd():
g = generate_graph(grad=True)
h1 = g.get_n_repr()
# partial set
v = th.tensor([1, 2, 8])
hh = Variable(th.zeros((len(v), D)), requires_grad=True)
g.set_n_repr(hh, v)
h2 = g.get_n_repr()
h2.backward(th.ones((10, D)) * 2)
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_batch_send():
g = generate_graph()
def _fmsg(hu, edge):
assert hu.shape == (5, D)
return hu
g.register_message_func(_fmsg)
# many-many send
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
# one-many send
u = th.tensor([0])
v = th.tensor([1, 2, 3, 4, 5])
g.send(u, v)
# many-one send
u = th.tensor([1, 2, 3, 4, 5])
v = th.tensor([9])
g.send(u, v)
def test_batch_recv():
g = generate_graph()
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
reduce_msg_shapes.clear()
g.send(u, v)
g.recv(th.unique(v))
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_update_routines():
g = generate_graph()
g.register_message_func(message_func)
g.register_reduce_func(reduce_func)
# send_and_recv
reduce_msg_shapes.clear()
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.send_and_recv(u, v)
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
# pull
v = th.tensor([1, 2, 3, 9])
reduce_msg_shapes.clear()
g.pull(v)
assert(reduce_msg_shapes == {(1, 8, D), (3, 1, D)})
reduce_msg_shapes.clear()
# push
v = th.tensor([0, 1, 2, 3])
reduce_msg_shapes.clear()
g.push(v)
assert(reduce_msg_shapes == {(1, 3, D), (8, 1, D)})
reduce_msg_shapes.clear()
# update_all
reduce_msg_shapes.clear()
g.update_all()
assert(reduce_msg_shapes == {(1, 8, D), (9, 1, D)})
reduce_msg_shapes.clear()
if __name__ == '__main__':
test_batch_setter_getter()
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_routines()
...@@ -18,8 +18,8 @@ def tree1(): ...@@ -18,8 +18,8 @@ def tree1():
g.add_edge(4, 1) g.add_edge(4, 1)
g.add_edge(1, 0) g.add_edge(1, 0)
g.add_edge(2, 0) g.add_edge(2, 0)
g.set_n_repr(th.Tensor([0, 1, 2, 3, 4])) g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr(th.randn(4, 10)) g.set_e_repr({'h' : th.randn(4, 10)})
return g return g
def tree2(): def tree2():
...@@ -37,17 +37,17 @@ def tree2(): ...@@ -37,17 +37,17 @@ def tree2():
g.add_edge(0, 4) g.add_edge(0, 4)
g.add_edge(4, 1) g.add_edge(4, 1)
g.add_edge(3, 1) g.add_edge(3, 1)
g.set_n_repr(th.Tensor([0, 1, 2, 3, 4])) g.set_n_repr({'h' : th.Tensor([0, 1, 2, 3, 4])})
g.set_e_repr(th.randn(4, 10)) g.set_e_repr({'h' : th.randn(4, 10)})
return g return g
def test_batch_unbatch(): def test_batch_unbatch():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
n1 = t1.get_n_repr() n1 = t1.get_n_repr()['h']
n2 = t2.get_n_repr() n2 = t2.get_n_repr()['h']
e1 = t1.get_e_repr() e1 = t1.get_e_repr()['h']
e2 = t2.get_e_repr() e2 = t2.get_e_repr()['h']
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
assert bg.number_of_nodes() == 10 assert bg.number_of_nodes() == 10
...@@ -57,10 +57,10 @@ def test_batch_unbatch(): ...@@ -57,10 +57,10 @@ def test_batch_unbatch():
assert bg.batch_num_edges == [4, 4] assert bg.batch_num_edges == [4, 4]
tt1, tt2 = dgl.unbatch(bg) tt1, tt2 = dgl.unbatch(bg)
assert th.allclose(t1.get_n_repr(), tt1.get_n_repr()) assert th.allclose(t1.get_n_repr()['h'], tt1.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr(), tt1.get_e_repr()) assert th.allclose(t1.get_e_repr()['h'], tt1.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr(), tt2.get_n_repr()) assert th.allclose(t2.get_n_repr()['h'], tt2.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr(), tt2.get_e_repr()) assert th.allclose(t2.get_e_repr()['h'], tt2.get_e_repr()['h'])
def test_batch_unbatch1(): def test_batch_unbatch1():
t1 = tree1() t1 = tree1()
...@@ -74,20 +74,20 @@ def test_batch_unbatch1(): ...@@ -74,20 +74,20 @@ def test_batch_unbatch1():
assert b2.batch_num_edges == [4, 4, 4] assert b2.batch_num_edges == [4, 4, 4]
s1, s2, s3 = dgl.unbatch(b2) s1, s2, s3 = dgl.unbatch(b2)
assert th.allclose(t2.get_n_repr(), s1.get_n_repr()) assert th.allclose(t2.get_n_repr()['h'], s1.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr(), s1.get_e_repr()) assert th.allclose(t2.get_e_repr()['h'], s1.get_e_repr()['h'])
assert th.allclose(t1.get_n_repr(), s2.get_n_repr()) assert th.allclose(t1.get_n_repr()['h'], s2.get_n_repr()['h'])
assert th.allclose(t1.get_e_repr(), s2.get_e_repr()) assert th.allclose(t1.get_e_repr()['h'], s2.get_e_repr()['h'])
assert th.allclose(t2.get_n_repr(), s3.get_n_repr()) assert th.allclose(t2.get_n_repr()['h'], s3.get_n_repr()['h'])
assert th.allclose(t2.get_e_repr(), s3.get_e_repr()) assert th.allclose(t2.get_e_repr()['h'], s3.get_e_repr()['h'])
def test_batch_sendrecv(): def test_batch_sendrecv():
t1 = tree1() t1 = tree1()
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src) bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1)) bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
u = [3, 4, 2 + 5, 0 + 5] u = [3, 4, 2 + 5, 0 + 5]
v = [1, 1, 4 + 5, 4 + 5] v = [1, 1, 4 + 5, 4 + 5]
...@@ -95,8 +95,8 @@ def test_batch_sendrecv(): ...@@ -95,8 +95,8 @@ def test_batch_sendrecv():
bg.recv(v) bg.recv(v)
t1, t2 = dgl.unbatch(bg) t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()[1] == 7 assert t1.get_n_repr()['h'][1] == 7
assert t2.get_n_repr()[4] == 2 assert t2.get_n_repr()['h'][4] == 2
def test_batch_propagate(): def test_batch_propagate():
...@@ -104,8 +104,8 @@ def test_batch_propagate(): ...@@ -104,8 +104,8 @@ def test_batch_propagate():
t2 = tree2() t2 = tree2()
bg = dgl.batch([t1, t2]) bg = dgl.batch([t1, t2])
bg.register_message_func(lambda src, edge: src) bg.register_message_func(lambda src, edge: {'m' : src['h']})
bg.register_reduce_func(lambda node, msgs: th.sum(msgs, 1)) bg.register_reduce_func(lambda node, msgs: {'h' : th.sum(msgs['m'], 1)})
# get leaves. # get leaves.
order = [] order = []
...@@ -123,23 +123,23 @@ def test_batch_propagate(): ...@@ -123,23 +123,23 @@ def test_batch_propagate():
bg.propagate(traverser=order) bg.propagate(traverser=order)
t1, t2 = dgl.unbatch(bg) t1, t2 = dgl.unbatch(bg)
assert t1.get_n_repr()[0] == 9 assert t1.get_n_repr()['h'][0] == 9
assert t2.get_n_repr()[1] == 5 assert t2.get_n_repr()['h'][1] == 5
def test_batched_edge_ordering(): def test_batched_edge_ordering():
g1 = dgl.DGLGraph() g1 = dgl.DGLGraph()
g1.add_nodes(6) g1.add_nodes(6)
g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1]) g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
e1 = th.randn(5, 10) e1 = th.randn(5, 10)
g1.set_e_repr(e1) g1.set_e_repr({'h' : e1})
g2 = dgl.DGLGraph() g2 = dgl.DGLGraph()
g2.add_nodes(6) g2.add_nodes(6)
g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0]) g2.add_edges([0, 1 ,2 ,5, 4 ,5], [1, 2, 3, 4, 3, 0])
e2 = th.randn(6, 10) e2 = th.randn(6, 10)
g2.set_e_repr(e2) g2.set_e_repr({'h' : e2})
g = dgl.batch([g1, g2]) g = dgl.batch([g1, g2])
r1 = g.get_e_repr()[g.edge_id(4, 5)] r1 = g.get_e_repr()['h'][g.edge_id(4, 5)]
r2 = g1.get_e_repr()[g1.edge_id(4, 5)] r2 = g1.get_e_repr()['h'][g1.edge_id(4, 5)]
assert th.equal(r1, r2) assert th.equal(r1, r2)
def test_batch_no_edge(): def test_batch_no_edge():
......
import torch as th import torch as th
import dgl import dgl
import dgl.function as fn import dgl.function as fn
from dgl.graph import __REPR__
def generate_graph(): def generate_graph():
g = dgl.DGLGraph() g = dgl.DGLGraph()
...@@ -37,18 +36,9 @@ def generate_graph1(): ...@@ -37,18 +36,9 @@ def generate_graph1():
g.set_e_repr(h) g.set_e_repr(h)
return g return g
def reducer_msg(node, msgs):
return th.sum(msgs['m'], 1)
def reducer_out(node, msgs):
return {'h' : th.sum(msgs, 1)}
def reducer_both(node, msgs): def reducer_both(node, msgs):
return {'h' : th.sum(msgs['m'], 1)} return {'h' : th.sum(msgs['m'], 1)}
def reducer_none(node, msgs):
return th.sum(msgs, 1)
def test_copy_src(): def test_copy_src():
# copy_src with both fields # copy_src with both fields
g = generate_graph() g = generate_graph()
...@@ -58,30 +48,6 @@ def test_copy_src(): ...@@ -58,30 +48,6 @@ def test_copy_src():
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with only src field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_src(src='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_src with no src field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_src(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy src with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_src())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_copy_edge(): def test_copy_edge():
# copy_edge with both fields # copy_edge with both fields
g = generate_graph() g = generate_graph()
...@@ -91,30 +57,6 @@ def test_copy_edge(): ...@@ -91,30 +57,6 @@ def test_copy_edge():
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.])) th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with only edge field; the out field should use anonymous repr
g = generate_graph()
g.register_message_func(fn.copy_edge(edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy_edge with no edge field; should use anonymous repr
g = generate_graph1()
g.register_message_func(fn.copy_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
# copy edge with no fields;
g = generate_graph1()
g.register_message_func(fn.copy_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([10., 1., 1., 1., 1., 1., 1., 1., 1., 44.]))
def test_src_mul_edge(): def test_src_mul_edge():
# src_mul_edge with all fields # src_mul_edge with all fields
g = generate_graph() g = generate_graph()
...@@ -124,34 +66,6 @@ def test_src_mul_edge(): ...@@ -124,34 +66,6 @@ def test_src_mul_edge():
assert th.allclose(g.get_n_repr()['h'], assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.])) th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph()
g.register_message_func(fn.src_mul_edge(src='h', edge='h'))
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge(out='m'))
g.register_reduce_func(reducer_both)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_out)
g.update_all()
assert th.allclose(g.get_n_repr()['h'],
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
g = generate_graph1()
g.register_message_func(fn.src_mul_edge())
g.register_reduce_func(reducer_none)
g.update_all()
assert th.allclose(g.get_n_repr(),
th.tensor([100., 1., 1., 1., 1., 1., 1., 1., 1., 284.]))
if __name__ == '__main__': if __name__ == '__main__':
test_copy_src() test_copy_src()
test_copy_edge() test_copy_edge()
......
...@@ -5,35 +5,31 @@ import dgl ...@@ -5,35 +5,31 @@ import dgl
D = 5 D = 5
def check_eq(a, b):
return a.shape == b.shape and np.allclose(a.numpy(), b.numpy())
def test_line_graph(): def test_line_graph():
N = 5 N = 5
G = dgl.DGLGraph(nx.star_graph(N)) G = dgl.DGLGraph(nx.star_graph(N))
G.set_e_repr(th.randn((2 * N, D))) G.set_e_repr({'h' : th.randn((2 * N, D))})
n_edges = G.number_of_edges() n_edges = G.number_of_edges()
L = G.line_graph(shared=True) L = G.line_graph(shared=True)
assert L.number_of_nodes() == 2 * N assert L.number_of_nodes() == 2 * N
L.set_n_repr(th.randn((2 * N, D))) L.set_n_repr({'h' : th.randn((2 * N, D))})
# update node features on line graph should reflect to edge features on # update node features on line graph should reflect to edge features on
# original graph. # original graph.
u = [0, 0, 2, 3] u = [0, 0, 2, 3]
v = [1, 2, 0, 0] v = [1, 2, 0, 0]
eid = G.edge_ids(u, v) eid = G.edge_ids(u, v)
L.set_n_repr(th.zeros((4, D)), eid) L.set_n_repr({'h' : th.zeros((4, D))}, eid)
assert check_eq(G.get_e_repr(u, v), th.zeros((4, D))) assert th.allclose(G.get_e_repr(u, v)['h'], th.zeros((4, D)))
# adding a new node feature on line graph should also reflect to a new # adding a new node feature on line graph should also reflect to a new
# edge feature on original graph # edge feature on original graph
data = th.randn(n_edges, D) data = th.randn(n_edges, D)
L.set_n_repr({'w': data}) L.set_n_repr({'w': data})
assert check_eq(G.get_e_repr()['w'], data) assert th.allclose(G.get_e_repr()['w'], data)
def test_no_backtracking(): def test_no_backtracking():
N = 5 N = 5
G = dgl.DGLGraph(nx.star_graph(N)) G = dgl.DGLGraph(nx.star_graph(N))
G.set_e_repr(th.randn((2 * N, D)))
L = G.line_graph(backtracking=False) L = G.line_graph(backtracking=False)
assert L.number_of_nodes() == 2 * N assert L.number_of_nodes() == 2 * N
for i in range(1, N): for i in range(1, N):
......
...@@ -22,23 +22,23 @@ def generate_graph(): ...@@ -22,23 +22,23 @@ def generate_graph():
def test_update_all(): def test_update_all():
def _test(fld): def _test(fld):
def message_func(hu, edge): def message_func(hu, edge):
return hu[fld] return {'m' : hu[fld]}
def message_func_edge(hu, edge): def message_func_edge(hu, edge):
if len(hu[fld].shape) == 1: if len(hu[fld].shape) == 1:
return hu[fld] * edge['e1'] return {'m' : hu[fld] * edge['e1']}
else: else:
return hu[fld] * edge['e2'] return {'m' : hu[fld] * edge['e2']}
def reduce_func(hv, msgs): def reduce_func(hv, msgs):
return {fld : th.sum(msgs, 1)} return {fld : th.sum(msgs['m'], 1)}
def apply_func(hu): def apply_func(hu):
return {fld : 2 * hu[fld]} return {fld : 2 * hu[fld]}
g = generate_graph() g = generate_graph()
# update all # update all
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.update_all(fn.copy_src(src=fld), fn.sum(out=fld), apply_func) g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func) g.update_all(message_func, reduce_func, apply_func)
...@@ -46,12 +46,12 @@ def test_update_all(): ...@@ -46,12 +46,12 @@ def test_update_all():
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
# update all with edge weights # update all with edge weights
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1'), g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
fn.sum(out=fld), apply_func) fn.sum(msg='m', out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2'), g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
fn.sum(out=fld), apply_func) fn.sum(msg='m', out=fld), apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func) g.update_all(message_func_edge, reduce_func, apply_func)
...@@ -68,42 +68,40 @@ def test_send_and_recv(): ...@@ -68,42 +68,40 @@ def test_send_and_recv():
v = th.tensor([1, 2, 3, 9, 9, 0]) v = th.tensor([1, 2, 3, 9, 9, 0])
def _test(fld): def _test(fld):
def message_func(hu, edge): def message_func(hu, edge):
return hu[fld] return {'m' : hu[fld]}
def message_func_edge(hu, edge): def message_func_edge(hu, edge):
if len(hu[fld].shape) == 1: if len(hu[fld].shape) == 1:
return hu[fld] * edge['e1'] return {'m' : hu[fld] * edge['e1']}
else: else:
return hu[fld] * edge['e2'] return {'m' : hu[fld] * edge['e2']}
def reduce_func(hv, msgs): def reduce_func(hv, msgs):
return {fld : th.sum(msgs, 1)} return {fld : th.sum(msgs['m'], 1)}
def apply_func(hu): def apply_func(hu):
return {fld : 2 * hu[fld]} return {fld : 2 * hu[fld]}
g = generate_graph() g = generate_graph()
# send and recv # send and recv
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.copy_src(src=fld), g.send_and_recv(u, v, fn.copy_src(src=fld, out='m'),
fn.sum(out=fld), apply_func) fn.sum(msg='m', out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func, g.send_and_recv(u, v, message_func, reduce_func, apply_func)
reduce_func, apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
# send and recv with edge weights # send and recv with edge weights
v1 = g.get_n_repr()[fld] v1 = g.get_n_repr()[fld]
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1'), g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e1', out='m'),
fn.sum(out=fld), apply_func) fn.sum(msg='m', out=fld), apply_func)
v2 = g.get_n_repr()[fld] v2 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2'), g.send_and_recv(u, v, fn.src_mul_edge(src=fld, edge='e2', out='m'),
fn.sum(out=fld), apply_func) fn.sum(msg='m', out=fld), apply_func)
v3 = g.get_n_repr()[fld] v3 = g.get_n_repr()[fld]
g.set_n_repr({fld : v1}) g.set_n_repr({fld : v1})
g.send_and_recv(u, v, message_func_edge, g.send_and_recv(u, v, message_func_edge, reduce_func, apply_func)
reduce_func, apply_func)
v4 = g.get_n_repr()[fld] v4 = g.get_n_repr()[fld]
assert th.allclose(v2, v3) assert th.allclose(v2, v3)
assert th.allclose(v3, v4) assert th.allclose(v3, v4)
...@@ -127,19 +125,19 @@ def test_update_all_multi_fn(): ...@@ -127,19 +125,19 @@ def test_update_all_multi_fn():
fld = 'f2' fld = 'f2'
# update all, mix of builtin and UDF # update all, mix of builtin and UDF
g.update_all([fn.copy_src(src=fld, out='m1'), message_func], g.update_all([fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func], [fn.sum(msg='m1', out='v1'), reduce_func],
None) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# run builtin with single message and reduce # run builtin with single message and reduce
g.update_all(fn.copy_src(src=fld), fn.sum(out='v1'), None) g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'), None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr # 1 message, 2 reduces
g.update_all(fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None) g.update_all(fn.copy_src(src=fld, out='m'), [fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')], None)
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3'] v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
...@@ -147,7 +145,7 @@ def test_update_all_multi_fn(): ...@@ -147,7 +145,7 @@ def test_update_all_multi_fn():
# update all with edge weights, 2 message, 3 reduces # update all with edge weights, 2 message, 3 reduces
g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')], g.update_all([fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')], [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
None) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
...@@ -181,20 +179,23 @@ def test_send_and_recv_multi_fn(): ...@@ -181,20 +179,23 @@ def test_send_and_recv_multi_fn():
# send and recv, mix of builtin and UDF # send and recv, mix of builtin and UDF
g.send_and_recv(u, v, g.send_and_recv(u, v,
[fn.copy_src(src=fld, out='m1'), message_func], [fn.copy_src(src=fld, out='m1'), message_func],
[fn.sum(msgs='m1', out='v1'), reduce_func], [fn.sum(msg='m1', out='v1'), reduce_func],
None) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# run builtin with single message and reduce # run builtin with single message and reduce
g.send_and_recv(u, v, fn.copy_src(src=fld), fn.sum(out='v1'), g.send_and_recv(u, v, fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out='v1'),
None) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
# 1 message, 2 reduces, using anonymous repr # 1 message, 2 reduces
g.send_and_recv(u, v, fn.copy_src(src=fld), [fn.sum(out='v2'), fn.sum(out='v3')], None) g.send_and_recv(u, v,
fn.copy_src(src=fld, out='m'),
[fn.sum(msg='m', out='v2'), fn.sum(msg='m', out='v3')],
None)
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
v3 = g.get_n_repr()['v3'] v3 = g.get_n_repr()['v3']
assert th.allclose(v1, v2) assert th.allclose(v1, v2)
...@@ -203,7 +204,7 @@ def test_send_and_recv_multi_fn(): ...@@ -203,7 +204,7 @@ def test_send_and_recv_multi_fn():
# send and recv with edge weights, 2 message, 3 reduces # send and recv with edge weights, 2 message, 3 reduces
g.send_and_recv(u, v, g.send_and_recv(u, v,
[fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')], [fn.src_mul_edge(src=fld, edge='e1', out='m1'), fn.src_mul_edge(src=fld, edge='e2', out='m2')],
[fn.sum(msgs='m1', out='v1'), fn.sum(msgs='m2', out='v2'), fn.sum(msgs='m1', out='v3')], [fn.sum(msg='m1', out='v1'), fn.sum(msg='m2', out='v2'), fn.sum(msg='m1', out='v3')],
None) None)
v1 = g.get_n_repr()['v1'] v1 = g.get_n_repr()['v1']
v2 = g.get_n_repr()['v2'] v2 = g.get_n_repr()['v2']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment