Unverified Commit 61fa3c6c authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Builtin function and API changes (#53)

* WIP: API renaming

* API rewrite and node function refactor

* builtin functions

* builtin functions tested

* fix test

* send and recv spmv test

* WIP: fix examples

* Fix examples using new APIs
parent 8c71f3f8
......@@ -30,7 +30,7 @@ class GATReduce(nn.Module):
e = F.softmax(F.leaky_relu(a), dim=0)
if self.attn_drop != 0.0:
e = F.dropout(e, self.attn_drop)
return torch.sum(e * ft, dim=0) # shape (D,)
return {'accum' : torch.sum(e * ft, dim=0)} # shape (D,)
class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual):
......@@ -43,8 +43,8 @@ class GATFinalize(nn.Module):
if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node, accum):
ret = accum
def forward(self, node):
ret = node['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
......
......@@ -33,7 +33,7 @@ class GATReduce(nn.Module):
e = F.softmax(F.leaky_relu(a), dim=1)
if self.attn_drop != 0.0:
e = F.dropout(e, self.attn_drop)
return torch.sum(e * ft, dim=1) # shape (B, D)
return {'accum' : torch.sum(e * ft, dim=1)} # shape (B, D)
class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual):
......@@ -46,8 +46,8 @@ class GATFinalize(nn.Module):
if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node, accum):
ret = accum
def forward(self, node):
ret = node['accum']
if self.residual:
if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret
......
......@@ -16,16 +16,16 @@ def gcn_msg(src, edge):
return src['h']
def gcn_reduce(node, msgs):
return sum(msgs)
return {'h' : sum(msgs)}
class NodeUpdateModule(nn.Module):
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
h = self.linear(accum)
def forward(self, node):
h = self.linear(node['h'])
if self.activation:
h = self.activation(h)
return {'h' : h}
......@@ -43,12 +43,12 @@ class GCN(nn.Module):
self.g = DGLGraph(nx_graph)
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes))
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features, train_nodes):
for n, feat in features.items():
......
......@@ -21,14 +21,14 @@ def gcn_msg(src, edge):
def gcn_reduce(node, msgs):
return torch.sum(msgs, 1)
class NodeUpdateModule(nn.Module):
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
h = self.linear(accum)
def forward(self, node):
h = self.linear(node)
if self.activation:
h = self.activation(h)
return h
......@@ -46,12 +46,12 @@ class GCN(nn.Module):
self.g = g
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes))
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr(features)
......
......@@ -12,17 +12,18 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
class NodeUpdateModule(nn.Module):
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__()
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node, accum):
h = self.linear(accum)
def forward(self, node):
h = self.linear(node)
if self.activation:
h = self.activation(h)
return h
......@@ -40,12 +41,12 @@ class GCN(nn.Module):
self.g = g
self.dropout = dropout
# input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers
for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes))
self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features):
self.g.set_n_repr(features)
......@@ -54,7 +55,7 @@ class GCN(nn.Module):
if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val)
self.g.update_all('from_src', 'sum', layer, batchable=True)
self.g.update_all(fn.copy_src(), fn.sum(), layer, batchable=True)
return self.g.pop_n_repr()
def main(args):
......
......@@ -21,10 +21,10 @@ import dgl.context as ctx
def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes()
if cuda:
adjmat = g.cached_graph.adjmat(ctx.gpu(args.gpu))
adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
mask = th.ones((n, 1)).cuda()
else:
adjmat = g.cached_graph.adjmat(ctx.cpu())
adjmat = g.cached_graph.adjmat().get(ctx.cpu())
mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.:
......@@ -59,6 +59,9 @@ def main(args):
args.dropout)
if cuda:
model.cuda()
zero_initializer = lambda shape : th.zeros(shape).cuda()
else:
zero_initializer = th.zeros
print(model)
optimizer = optim.Adagrad(model.parameters(),
lr=args.lr,
......@@ -68,21 +71,14 @@ def main(args):
t_epoch = time.time()
for step, batch in enumerate(train_loader):
g = batch.graph
n = g.number_of_nodes()
x = th.zeros((n, args.x_size))
h = th.zeros((n, args.h_size))
c = th.zeros((n, args.h_size))
if cuda:
batch = _batch_to_cuda(batch)
x = x.cuda()
h = h.cuda()
c = c.cuda()
if step >= 3:
t0 = time.time()
# traverse graph
giter = list(tensor_topo_traverse(g, False, args))
logits = model(batch, x, h, c, iterator=giter, train=True)
logits = model(batch, zero_initializer, iterator=giter, train=True)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label)
optimizer.zero_grad()
......
......@@ -52,19 +52,13 @@ class ChildSumTreeLSTMCell(nn.Module):
c_tild = th.sum(f * msgs['c'], 1)
return {'h_tild' : h_tild, 'c_tild' : c_tild}
def update_func(self, node, accum):
def apply_func(self, node):
# equation (3), (5), (6)
if accum is None:
iou = self.W_iou(node['x'])
else:
iou = self.W_iou(node['x']) + self.U_iou(accum['h_tild'])
iou = self.W_iou(node['x']) + self.U_iou(node['h_tild'])
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (7)
if accum is None:
c = i * u
else:
c = i * u + accum['c_tild']
c = i * u + node['c_tild']
# equation (8)
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
......@@ -79,6 +73,7 @@ class TreeLSTM(nn.Module):
cell_type='childsum'):
super(TreeLSTM, self).__init__()
self.x_size = x_size
self.h_size = h_size
# TODO(minjie): pre-trained embedding like GLoVe
self.embedding = nn.Embedding(num_vocabs, x_size)
self.dropout = nn.Dropout(dropout)
......@@ -88,20 +83,20 @@ class TreeLSTM(nn.Module):
else:
raise RuntimeError('Unknown cell type:', cell_type)
def forward(self, batch, x, h, c, iterator=None, train=True):
def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
"""Compute tree-lstm prediction given a batch.
Parameters
----------
batch : dgl.data.SSTBatch
The data batch.
x : Tensor
Initial node input.
h : Tensor
zero_initializer : callable
Function to return zero value tensor.
h : Tensor, optional
Initial hidden state.
c : Tensor
c : Tensor, optional
Initial cell state.
iterator : graph iterator
iterator : graph iterator, optional
External iterator on graph.
Returns
......@@ -110,21 +105,29 @@ class TreeLSTM(nn.Module):
The prediction of each node.
"""
g = batch.graph
n = g.number_of_nodes()
g.register_message_func(self.cell.message_func, batchable=True)
g.register_reduce_func(self.cell.reduce_func, batchable=True)
g.register_update_func(self.cell.update_func, batchable=True)
g.register_apply_node_func(self.cell.apply_func, batchable=True)
# feed embedding
embeds = self.embedding(batch.wordid)
x = zero_initializer((n, self.x_size))
x = x.index_copy(0, batch.nid_with_word, embeds)
g.set_n_repr({'x' : x, 'h' : h, 'c' : c})
if h is None:
h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size))
if c is None:
c = zero_initializer((n, self.h_size))
c_tild = zero_initializer((n, self.h_size))
g.set_n_repr({'x' : x, 'h' : h, 'c' : c, 'h_tild' : h_tild, 'c_tild' : c_tild})
# TODO(minjie): potential bottleneck
if iterator is None:
for frontier in topological_traverse(g):
#print('frontier', frontier)
g.update_to(frontier)
g.pull(frontier)
else:
for frontier in iterator:
g.update_to(frontier)
g.pull(frontier)
# compute logits
h = g.pop_n_repr('h')
h = self.dropout(h)
......
......@@ -63,6 +63,7 @@ zeros = th.zeros
spmm = th.spmm
sort = th.sort
arange = th.arange
mul = th.mul
def to_context(x, ctx):
if ctx is None:
......
......@@ -2,5 +2,9 @@
# A special argument for selecting all nodes/edges.
ALL = "__ALL__"
def is_all(arg):
return isinstance(arg, str) and arg == ALL
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
"""Built-in functors."""
from __future__ import absolute_import
import dgl.backend as F
def message_from_src(src, edge):
    """Built-in message function: forward the source node representation
    unchanged as the message; the edge representation is ignored."""
    return src
def reduce_sum(node, msgs):
    """Built-in reduce function: sum the incoming messages.

    A list of dict messages is summed per key; a list of plain messages
    is summed directly; otherwise *msgs* is assumed to be a batched
    tensor and is reduced along dimension 1 via the backend.
    """
    if not isinstance(msgs, list):
        # Batched tensor path: reduce over the message dimension.
        return F.sum(msgs, 1)
    first = msgs[0]
    if isinstance(first, dict):
        return {key: sum(msg[key] for msg in msgs) for key in first}
    return sum(msgs)
def reduce_max(node, msgs):
    """Built-in reduce function: element-wise max of the incoming messages.

    Mirrors ``reduce_sum``: a list of dict messages is reduced per key,
    a list of plain messages with ``max``, and a batched tensor along
    dimension 1 via the backend.
    """
    if not isinstance(msgs, list):
        return F.max(msgs, 1)
    if isinstance(msgs[0], dict):
        # Fix: dict messages previously fell through to max(msgs), which
        # raises TypeError because dicts are unorderable (reduce_sum
        # already handled the dict case; this makes the two consistent).
        return {key: max(msg[key] for msg in msgs) for key in msgs[0]}
    return max(msgs)
......@@ -118,8 +118,8 @@ class CachedGraph:
dst = utils.toindex(dst)
return src, dst
@utils.ctx_cached_member
def adjmat(self, ctx):
@utils.cached_member
def adjmat(self):
"""Return a sparse adjacency matrix.
The row dimension represents the dst nodes; the column dimension
......@@ -134,8 +134,7 @@ class CachedGraph:
n = self._graph.vcount()
dat = F.ones((len(elist),))
mat = F.sparse_tensor(idx, dat, [n, n])
mat = F.to_context(mat, ctx)
return mat
return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
def freeze(self):
self._freeze = True
......
......@@ -22,7 +22,6 @@ _urls = {
class CitationGraphDataset(object):
def __init__(self, name):
self.name = name
self.mode = mode
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_urls[name], path=self.zip_file_path)
......
......@@ -163,6 +163,10 @@ class FrameRef(MutableMapping):
def update_rows(self, query, other):
rowids = self._getrowid(query)
for key, col in other.items():
if key not in self:
# add new column
tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col)
idx = rowids.totensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col)
......
from .message import *
from .reducer import *
"""Built-in message function."""
from __future__ import absolute_import
import operator
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object):
    """Abstract base class for built-in message functions.

    Subclasses implement ``__call__(src, edge)`` to compute the message
    sent along an edge, and ``name()`` to identify the functor (used by
    the scheduler to pick specialized execution paths).
    """

    def __call__(self, src, edge):
        raise NotImplementedError

    def name(self):
        raise NotImplementedError
class BundledMessageFunction(MessageFunction):
    """Compose several message functions into one.

    Each sub-function is invoked in order and its dict result merged into
    the accumulated message; every sub-function must therefore write to a
    distinct ``out_field`` so the merge succeeds.
    """

    def __init__(self, fn_list):
        # fn_list : list of MessageFunction (or compatible callables)
        self.fn_list = fn_list

    def __call__(self, src, edge):
        ret = None
        for fn in self.fn_list:
            msg = fn(src, edge)
            if ret is None:
                ret = msg
            else:
                try:
                    ret.update(msg)
                # Fix: the original `except e:` looked up the undefined name
                # `e` at catch time and raised NameError instead of reporting
                # the merge failure. Catch the actual failure modes: ret has
                # no .update (tensor result) or msg is not a mapping.
                except (AttributeError, TypeError, ValueError) as err:
                    raise RuntimeError("Failed to merge results of two builtin"
                                       " message functions. Please specify out_field"
                                       " for the builtin message function.") from err
        return ret

    def name(self):
        return "bundled"
class SrcMulEdgeMessageFunction(MessageFunction):
    """Built-in message function: multiply a source field by an edge field.

    When a field name is None the whole representation is used; when
    ``out_field`` is None the bare product is returned instead of a dict.
    """

    def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
        # mul_op : binary operator applied to (src_value, edge_value)
        self.mul_op = mul_op
        self.src_field = src_field
        self.edge_field = edge_field
        self.out_field = out_field

    def __call__(self, src, edge):
        lhs = src if self.src_field is None else src[self.src_field]
        rhs = edge if self.edge_field is None else edge[self.edge_field]
        product = self.mul_op(lhs, rhs)
        if self.out_field is None:
            return product
        return {self.out_field: product}

    def name(self):
        return "src_mul_edge"
class CopySrcMessageFunction(MessageFunction):
    """Built-in message function: copy (a field of) the source node.

    With ``src_field`` None the whole source representation is forwarded;
    with ``out_field`` None the value is returned bare rather than wrapped
    in a dict.
    """

    def __init__(self, src_field=None, out_field=None):
        self.src_field = src_field
        self.out_field = out_field

    def __call__(self, src, edge):
        payload = src[self.src_field] if self.src_field is not None else src
        return payload if self.out_field is None else {self.out_field: payload}

    def name(self):
        return "copy_src"
class CopyEdgeMessageFunction(MessageFunction):
    """Built-in message function: copy (a field of) the edge.

    With ``edge_field`` None the whole edge representation is forwarded;
    with ``out_field`` None the value is returned bare rather than wrapped
    in a dict.
    """

    def __init__(self, edge_field=None, out_field=None):
        self.edge_field = edge_field
        self.out_field = out_field

    def __call__(self, src, edge):
        payload = edge[self.edge_field] if self.edge_field is not None else edge
        return payload if self.out_field is None else {self.out_field: payload}

    def name(self):
        return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None):
    """Create a builtin message function multiplying a source field by an
    edge field (field/out names optional; None means whole repr / bare value)."""
    return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)

def copy_src(src=None, out=None):
    """Create a builtin message function that copies (a field of) the
    source node representation."""
    return CopySrcMessageFunction(src, out)

def copy_edge(edge=None, out=None):
    """Create a builtin message function that copies (a field of) the
    edge representation."""
    return CopyEdgeMessageFunction(edge, out)
"""Built-in reducer function."""
from __future__ import absolute_import
import dgl.backend as F
__all__ = ["ReduceFunction", "sum", "max"]
class ReduceFunction(object):
    """Abstract base class for built-in reduce functions.

    Subclasses implement ``__call__(node, msgs)`` to fold incoming
    messages into new node representations, and ``name()`` to identify
    the functor (used by the scheduler to pick specialized paths).
    """

    def __call__(self, node, msgs):
        raise NotImplementedError

    def name(self):
        raise NotImplementedError
class BundledReduceFunction(ReduceFunction):
    """Compose several reduce functions into one.

    Each sub-reducer runs on the same ``(node, msgs)`` pair and its dict
    result is merged into the accumulated representation; every
    sub-reducer must therefore write to a distinct ``out_field``.
    """

    def __init__(self, fn_list):
        # fn_list : list of ReduceFunction (or compatible callables)
        self.fn_list = fn_list

    def __call__(self, node, msgs):
        ret = None
        for fn in self.fn_list:
            rpr = fn(node, msgs)
            if ret is None:
                ret = rpr
            else:
                try:
                    ret.update(rpr)
                # Fix: the original `except e:` looked up the undefined name
                # `e` at catch time and raised NameError instead of reporting
                # the merge failure. Catch the actual failure modes: ret has
                # no .update (tensor result) or rpr is not a mapping.
                except (AttributeError, TypeError, ValueError) as err:
                    raise RuntimeError("Failed to merge results of two builtin"
                                       " reduce functions. Please specify out_field"
                                       " for the builtin reduce function.") from err
        return ret

    def name(self):
        return "bundled"
_python_sum = sum
_python_max = max


class SumReducerFunction(ReduceFunction):
    """Generic built-in reducer folding messages with a pair of fold ops.

    Parameters
    ----------
    batch_sum_op : callable
        Backend reduction applied along dimension 1 of batched tensor msgs.
    nonbatch_sum_op : callable
        Python-level reduction applied to a list of messages.
    msg_field : str, optional
        If given, reduce only this field of each message.
    out_field : str, optional
        If given, wrap the result as ``{out_field: result}``.
    name_str : str, optional
        Name reported to the scheduler; defaults to "sum" so existing
        callers are unaffected.
    """

    def __init__(self, batch_sum_op, nonbatch_sum_op, msg_field=None,
                 out_field=None, name_str="sum"):
        self.batch_sum_op = batch_sum_op
        self.nonbatch_sum_op = nonbatch_sum_op
        self.msg_field = msg_field
        self.out_field = out_field
        self._name_str = name_str

    def __call__(self, node, msgs):
        if isinstance(msgs, list):
            if self.msg_field is None:
                ret = self.nonbatch_sum_op(msgs)
            else:
                ret = self.nonbatch_sum_op([msg[self.msg_field] for msg in msgs])
        else:
            if self.msg_field is None:
                ret = self.batch_sum_op(msgs, 1)
            else:
                ret = self.batch_sum_op(msgs[self.msg_field], 1)
        if self.out_field is None:
            return ret
        return {self.out_field: ret}

    def name(self):
        return self._name_str


def sum(msgs=None, out=None):
    """Create a builtin sum reducer (shadows builtins.sum in this module)."""
    return SumReducerFunction(F.sum, _python_sum, msgs, out)


def max(msgs=None, out=None):
    """Create a builtin max reducer (shadows builtins.max in this module)."""
    # Fix: the max reducer previously reported name() == "sum", which can
    # misdirect any scheduler dispatch that keys on the reducer name.
    return SumReducerFunction(F.max, _python_max, msgs, out, name_str="max")
......@@ -6,10 +6,9 @@ import networkx as nx
from networkx.classes.digraph import DiGraph
import dgl
from dgl.base import ALL, is_all
from dgl.base import ALL, is_all, __MSG__, __REPR__
import dgl.backend as F
from dgl.backend import Tensor
import dgl.builtin as builtin
from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context
from dgl.frame import FrameRef, merge_frames
......@@ -17,9 +16,6 @@ from dgl.nx_adapt import nx_init
import dgl.scheduler as scheduler
import dgl.utils as utils
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs.
......@@ -58,10 +54,11 @@ class DGLGraph(DiGraph):
# other class members
self._msg_graph = None
self._msg_frame = FrameRef()
self._message_func = None
self._reduce_func = None
self._update_func = None
self._edge_func = None
self._message_func = (None, None)
self._reduce_func = (None, None)
self._edge_func = (None, None)
self._apply_node_func = (None, None)
self._apply_edge_func = (None, None)
def node_attr_schemes(self):
return self._node_frame.schemes
......@@ -281,33 +278,33 @@ class DGLGraph(DiGraph):
else:
return self._edge_frame.select_rows(eid)
def register_message_func(self,
message_func,
def register_edge_func(self,
edge_func,
batchable=False):
"""Register global message function.
"""Register global edge update function.
Parameters
----------
message_func : callable
edge_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._message_func = (message_func, batchable)
self._edge_func = (edge_func, batchable)
def register_edge_func(self,
edge_func,
def register_message_func(self,
message_func,
batchable=False):
"""Register global edge update function.
"""Register global message function.
Parameters
----------
edge_func : callable
message_func : callable
Message function on the edge.
batchable : bool
Whether the provided message function allows batch computing.
"""
self._edge_func = (edge_func, batchable)
self._message_func = (message_func, batchable)
def register_reduce_func(self,
reduce_func,
......@@ -323,21 +320,94 @@ class DGLGraph(DiGraph):
"""
self._reduce_func = (reduce_func, batchable)
def register_update_func(self,
update_func,
def register_apply_node_func(self,
apply_node_func,
batchable=False):
"""Register global node update function.
"""Register global node apply function.
Parameters
----------
update_func : callable
Update function on the node.
apply_node_func : callable
Apply function on the node.
batchable : bool
Whether the provided update function allows batch computing.
Whether the provided function allows batch computing.
"""
self._update_func = (update_func, batchable)
self._apply_node_func = (apply_node_func, batchable)
def sendto(self, u, v, message_func=None, batchable=False):
def register_apply_edge_func(self,
apply_edge_func,
batchable=False):
"""Register global edge apply function.
Parameters
----------
apply_edge_func : callable
Apply function on the edge.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_edge_func = (apply_edge_func, batchable)
def apply_nodes(self, v, apply_node_func="default", batchable=False):
"""Apply the function on node representations.
Parameters
----------
v : int, iterable of int, tensor
The node id(s).
apply_node_func : callable
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func
if not apply_node_func:
# Skip none function call.
return
if batchable:
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
else:
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False):
"""Apply the function on edge representations.
Parameters
----------
u : int, iterable of int, tensor
The src node id(s).
v : int, iterable of int, tensor
The dst node id(s).
apply_edge_func : callable
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if batchable:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
if is_all(u) == is_all(v):
u, v = zip(*self.edges)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def send(self, u, v, message_func="default", batchable=False):
"""Trigger the message function on edge u->v
The message function should be compatible with following signature:
......@@ -356,33 +426,31 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
message_func : str or callable
message_func : callable
The message function.
batchable : bool
Whether the function allows batched computation.
"""
if message_func is None:
if message_func == "default":
message_func, batchable = self._message_func
assert message_func is not None
if batchable:
self._batch_sendto(u, v, message_func)
self._batch_send(u, v, message_func)
else:
self._nonbatch_sendto(u, v, message_func)
self._nonbatch_send(u, v, message_func)
def _nonbatch_sendto(self, u, v, message_func):
f_msg = _get_message_func(message_func)
def _nonbatch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else:
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = f_msg(_get_repr(self.nodes[uu]),
ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret
def _batch_sendto(self, u, v, message_func):
f_msg = _get_message_func(message_func)
def _batch_send(self, u, v, message_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self.msg_graph.add_edges(u, v)
......@@ -405,7 +473,7 @@ class DGLGraph(DiGraph):
else:
self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u, v, edge_func=None, batchable=False):
def update_edge(self, u, v, edge_func="default", batchable=False):
"""Update representation on edge u->v
The edge function should be compatible with following signature:
......@@ -422,12 +490,12 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
edge_func : str or callable
edge_func : callable
The update function.
batchable : bool
Whether the function allows batched computation.
"""
if edge_func is None:
if edge_func == "default":
edge_func, batchable = self._edge_func
assert edge_func is not None
if batchable:
......@@ -470,142 +538,93 @@ class DGLGraph(DiGraph):
def recv(self,
u,
reduce_func=None,
update_func=None,
reduce_func="default",
apply_node_func="default",
batchable=False):
"""Receive in-coming messages and update representation on node u.
"""Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors
of node u. If no message is found from the predecessors, reduce function
will be skipped and a None type will be provided as the reduced messages for
the update function.
will be skipped.
The reduce function should be compatible with following signature:
(node_reprs, batched_messages) -> reduced_messages
(node_reprs, batched_messages) -> node_reprs
It computes the reduced edge representations using the representations
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages).
The reduce function can be any of the pre-defined functions ('sum',
'max'). If built-in function is used, computation will be performed
efficiently (using generic-SPMV kernels).
The reduce function can also be pre-defined functions.
The update function should be compatible with following signature:
An optional apply_node function can be specified and should follow the
signature:
(node_reprs, reduced_messages) -> node_reprs
node_reprs -> node_reprs
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages) and the node
itself. All node_reprs and edge_reprs are dictionaries.
All node_reprs and edge_reprs support tensor and dictionary types.
Parameters
----------
u : node, container or tensor
The node to be updated.
reduce_func : str or callable
reduce_func : callable
The reduce function.
update_func : str or callable
apply_node_func : callable, optional
The update function.
batchable : bool
batchable : bool, optional
Whether the reduce and update function allows batched computation.
"""
if reduce_func is None:
if reduce_func == "default":
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert reduce_func is not None
assert update_func is not None
if batchable:
self._batch_recv(u, reduce_func, update_func)
self._batch_recv(u, reduce_func)
else:
self._nonbatch_recv(u, reduce_func, update_func)
self._nonbatch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func, batchable)
def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func)
def _nonbatch_recv(self, u, reduce_func):
if is_all(u):
u = list(range(0, self.number_of_nodes()))
else:
u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) == 0:
msgs_reduced = None
else:
msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
# update phase
ret = update_func(_get_repr(self.nodes[uu]), msgs_reduced)
_set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func):
if len(v) == 0:
# no vertex to be triggered.
return
null_v, reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
if all_reduced_msgs is None:
# no message; only do recv.
if is_all(v):
self.set_n_repr(update_func(self.get_n_repr(), None))
else:
self.set_n_repr(update_func(self.get_n_repr(v), None), v)
else:
# Compute new node repr for nodes with no in-coming messages.
if len(null_v) == 0:
new_null_ns = None
else:
new_null_ns = update_func(self.get_n_repr(null_v), None)
# Read the node states in the degree-bucketing order.
if len(reordered_v) == 0:
new_reordered_ns = None
else:
new_reordered_ns = update_func(self.get_n_repr(reordered_v), all_reduced_msgs)
v_tensor = utils.pack2(null_v.totensor(), reordered_v.totensor())
new_ns = utils.pack2(new_null_ns, new_reordered_ns)
if len(msgs_batch) != 0:
new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
_set_repr(self.nodes[uu], new_repr)
if is_all(v):
# First do reorder and then replace the whole column.
_, indices = F.sort(v_tensor)
indices = utils.toindex(indices)
# TODO(minjie): following code should be included in Frame somehow.
if utils.is_dict_like(new_ns):
for key, val in new_ns.items():
idx = indices.totensor(F.get_context(val))
self._node_frame[key] = F.gather_row(val, idx)
else:
idx = indices.totensor(F.get_context(new_ns))
self._node_frame[__REPR__] = F.gather_row(new_ns, idx)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, v_tensor)
def _batch_reduce(self, v, reduce_func):
def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0:
# no message has ever been sent
return None, None, None
return
if is_all(v):
v_is_all = is_all(v)
if v_is_all:
v = list(range(self.number_of_nodes()))
# sanity checks
if len(v) == 0:
# no vertex to be triggered.
return
v = utils.toindex(v)
f_reduce = _get_reduce_func(reduce_func)
# degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
null_v_bucket = None
non_null_v_buckets = []
reduced_msgs = []
for deg, v_bkt in zip(degrees, v_buckets):
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
if degrees == [0]:
# no message has been sent to the specified node
return
reordered_v = []
new_reprs = []
has_zero_degree = False
for deg, v_bkt in zip(degrees, v_buckets):
if deg == 0:
assert null_v_bucket is None
null_v_bucket = v_bkt
# no need to trigger reduce func for zero-degree nodes
has_zero_degree = True
continue
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids)
......@@ -619,39 +638,36 @@ class DGLGraph(DiGraph):
else:
reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
non_null_v_buckets.append(v_bkt)
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs))
if len(reduced_msgs) == 0:
# no message has been sent to the specified node
return None, None, None
reordered_v.append(v_bkt.totensor())
new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
# TODO: clear partial messages
self.clear_messages()
# Read the node states in the degree-bucketing order.
null_v = utils.toindex(null_v_bucket or [])
reordered_v = utils.toindex(
F.pack([v_bkt.totensor() for v_bkt in non_null_v_buckets])
if len(non_null_v_buckets) > 0 else []
)
# Pack all reduced msgs together
if utils.is_dict_like(reduced_msgs[0]):
keys = reduced_msgs[0].keys()
all_reduced_msgs = {
key : F.pack([msg[key] for msg in reduced_msgs])
# Pack all reducer results together
reordered_v = F.pack(reordered_v)
if utils.is_dict_like(new_reprs[0]):
keys = new_reprs[0].keys()
new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
for key in keys}
else:
all_reduced_msgs = F.pack(reduced_msgs)
new_reprs = {__REPR__ : F.pack(new_reprs)}
return null_v, reordered_v, all_reduced_msgs
if v_is_all and not has_zero_degree:
# First do reorder and then replace the whole column.
_, indices = F.sort(reordered_v)
indices = utils.toindex(indices)
new_reprs = utils.reorder(new_reprs, indices)
self.set_n_repr(new_reprs)
else:
# Use setter to do reorder.
self.set_n_repr(new_reprs, reordered_v)
def update_by_edge(self,
def send_and_recv(self,
u, v,
message_func=None,
reduce_func=None,
update_func=None,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
"""Trigger the message function on u->v and update v.
......@@ -661,100 +677,50 @@ class DGLGraph(DiGraph):
The source node(s).
v : node, container or tensor
The destination node(s).
message_func : str or callable
message_func : callable
The message function.
reduce_func : str or callable
reduce_func : callable
The reduce function.
update_func : str or callable
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func is None:
u = utils.toindex(u)
v = utils.toindex(v)
if len(u) == 0:
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.totensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
message_func, batchable = self._message_func
if reduce_func is None:
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
assert message_func is not None
assert reduce_func is not None
assert update_func is not None
if batchable:
self._batch_update_by_edge(
u, v, message_func, reduce_func, update_func)
else:
self._nonbatch_update_by_edge(
u, v, message_func, reduce_func, update_func)
def _nonbatch_update_by_edge(
self,
u, v,
message_func,
reduce_func,
update_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
if batchable:
executor = scheduler.get_executor(
'send_and_recv', self, src=u, dst=v,
message_func=message_func, reduce_func=reduce_func)
else:
u = utils.toindex(u)
v = utils.toindex(v)
self._nonbatch_sendto(u, v, message_func)
dst = set()
for uu, vv in utils.edge_iter(u, v):
dst.add(vv)
self._nonbatch_recv(list(dst), reduce_func, update_func)
executor = None
def _batch_update_by_edge(
self,
u, v,
message_func,
reduce_func,
update_func):
if len(u) == 0:
# no message
assert len(v) == 0
elif is_all(u) and is_all(v):
self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v
u = utils.toindex(u)
v = utils.toindex(v)
# TODO(minjie): broadcasting is optional for many-one input.
u, v = utils.edge_broadcasting(u, v)
# relabel destination nodes.
new2old, old2new = utils.build_relabel_map(v)
u = u.totensor()
v = v.totensor()
# TODO(minjie): should not directly use []
new_v = old2new[v]
# create adj mat
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(ctx_adjmat.get(F.get_context(col)), col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr(new2old)
new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old)
if executor:
executor.run()
else:
u = utils.toindex(u)
v = utils.toindex(v)
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v.totensor())
self._batch_recv(unique_v, reduce_func, update_func)
self.send(u, v, message_func, batchable=batchable)
self.recv(unique_v, reduce_func, None, batchable=batchable)
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
def update_to(self,
def pull(self,
v,
message_func=None,
reduce_func=None,
update_func=None,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
"""Pull messages from the node's predecessors and then update it.
......@@ -762,45 +728,29 @@ class DGLGraph(DiGraph):
----------
v : node, container or tensor
The node to be updated.
message_func : str or callable
message_func : callable
The message function.
reduce_func : str or callable
reduce_func : callable
The reduce function.
update_func : str or callable
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func is None:
message_func, batchable = self._message_func
if reduce_func is None:
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert message_func is not None
assert reduce_func is not None
assert update_func is not None
if batchable:
v = utils.toindex(v)
uu, vv, orphan = self.cached_graph.in_edges(v)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
# trigger update function for nodes that have no incoming messages.
self._batch_recv(orphan, reduce_func, update_func)
else:
v = utils.toindex(v)
for vv in utils.node_iter(v):
assert vv in self.nodes
uu = list(self.pred[vv])
if len(uu) > 0:
self._nonbatch_sendto(uu, vv, message_func)
self._nonbatch_recv(vv, reduce_func, update_func)
if len(v) == 0:
return
uu, vv, _ = self.cached_graph.in_edges(v)
self.send_and_recv(uu, vv, message_func, reduce_func,
apply_node_func=None, batchable=batchable)
unique_v = F.unique(v.totensor())
self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
def update_from(self,
def push(self,
u,
message_func=None,
reduce_func=None,
update_func=None,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
"""Send message from the node to its successors and update them.
......@@ -808,91 +758,65 @@ class DGLGraph(DiGraph):
----------
u : node, container or tensor
The node that sends out messages.
message_func : str or callable
message_func : callable
The message function.
reduce_func : str or callable
reduce_func : callable
The reduce function.
update_func : str or callable
apply_node_func : callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func is None:
message_func, batchable = self._message_func
if reduce_func is None:
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert message_func is not None
assert reduce_func is not None
assert update_func is not None
if batchable:
u = utils.toindex(u)
if len(u) == 0:
return
uu, vv, _ = self.cached_graph.out_edges(u)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
else:
u = utils.toindex(u)
for uu in utils.node_iter(u):
assert uu in self.nodes
for v in self.succ[uu]:
self._nonbatch_update_by_edge(uu, v,
message_func, reduce_func, update_func)
self.send_and_recv(uu, vv, message_func,
reduce_func, apply_node_func, batchable=batchable)
def update_all(self,
message_func=None,
reduce_func=None,
update_func=None,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False):
"""Send messages through all the edges and update all nodes.
Parameters
----------
message_func : str or callable
message_func : callable
The message function.
reduce_func : str or callable
reduce_func : callable
The reduce function.
update_func : str or callable
apply_node_func : callable, optional
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
"""
if message_func is None:
if message_func == "default":
message_func, batchable = self._message_func
if reduce_func is None:
reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
if reduce_func == "default":
reduce_func, _ = self._reduce_func
assert message_func is not None
assert reduce_func is not None
assert update_func is not None
if batchable:
if message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
adjmat = self.cached_graph.adjmat(F.get_context(col))
reduced_msgs[key] = F.spmm(adjmat, col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr()
self.set_n_repr(update_func(node_repr, reduced_msgs))
executor = scheduler.get_executor(
"update_all", self, message_func=message_func, reduce_func=reduce_func)
else:
self._batch_sendto(ALL, ALL, message_func)
self._batch_recv(ALL, reduce_func, update_func)
executor = None
if executor:
executor.run()
else:
u, v = zip(*self.edges)
u = list(u)
v = list(v)
self._nonbatch_sendto(u, v, message_func)
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)
self.send(ALL, ALL, message_func, batchable=batchable)
self.recv(ALL, reduce_func, None, batchable=batchable)
self.apply_nodes(ALL, apply_node_func, batchable=batchable)
def propagate(self,
iterator='bfs',
message_func=None,
reduce_func=None,
update_func=None,
message_func="default",
reduce_func="default",
apply_node_func="default",
batchable=False,
**kwargs):
"""Propagate messages and update nodes using iterator.
......@@ -910,7 +834,7 @@ class DGLGraph(DiGraph):
The message function.
reduce_func : str or callable
The reduce function.
update_func : str or callable
apply_node_func : str or callable
The update function.
batchable : bool
Whether the reduce and update function allows batched computation.
......@@ -925,8 +849,8 @@ class DGLGraph(DiGraph):
else:
# NOTE: the iteration can return multiple edges at each step.
for u, v in iterator:
self.update_by_edge(u, v,
message_func, reduce_func, update_func, batchable)
self.send_and_recv(u, v,
message_func, reduce_func, apply_node_func, batchable)
def subgraph(self, nodes):
"""Generate the subgraph among the given nodes.
......@@ -1077,25 +1001,3 @@ def _set_repr(attr_dict, attr):
attr_dict.update(attr)
else:
attr_dict[__REPR__] = attr
def _get_reduce_func(reduce_func):
if isinstance(reduce_func, str):
# built-in reduce func
if reduce_func == 'sum':
return builtin.reduce_sum
elif reduce_func == 'max':
return builtin.reduce_max
else:
raise ValueError(
"Unknown built-in reduce function: %s" % reduce_func)
return reduce_func
def _get_message_func(message_func):
if isinstance(message_func, str):
# built-in message func
if message_func == 'from_src':
return builtin.message_from_src
else:
raise ValueError(
"Unknown built-in message function: %s" % message_func)
return message_func
......@@ -6,6 +6,9 @@ Code: https://github.com/tkipf/gcn
GCN with SPMV specialization.
"""
import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.base import ALL, is_all
class NodeUpdateModule(nn.Module):
......@@ -15,13 +18,8 @@ class NodeUpdateModule(nn.Module):
self.activation = activation
self.attribute = None
def set_attribute_to_update(self, attribute):
self.attribute = attribute
def forward(self, node, accum, attribute=None):
if self.attribute:
accum = accum[self.attribute]
h = self.linear(accum)
def forward(self, node):
h = self.linear(node['accum'])
if self.activation:
h = self.activation(h)
if self.attribute:
......@@ -41,9 +39,16 @@ class GCN(nn.Module):
self.update_func = NodeUpdateModule(in_feats, out_feats, activation)
def forward(self, g, u=ALL, v=ALL, attribute=None):
self.update_func.set_attribute_to_update(attribute)
if is_all(u) and is_all(v):
g.update_all('from_src', 'sum', self.update_func, batchable=True)
g.update_all(fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
else:
g.update_by_edge(u, v, 'from_src', 'sum', self.update_func, batchable=True)
g.send_and_recv(u, v,
fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
g.pop_n_repr('accum')
return g
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment