"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "2acafdafd91d4ad3df7079d63fc51f3f3c00813a"
Unverified Commit 61fa3c6c authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

Builtin function and API changes (#53)

* WIP: API renaming

* API rewrite and node function refactor

* builtin functions

* builtin functions tested

* fix test

* send and recv spmv test

* WIP: fix examples

* Fix examples using new APIs
parent 8c71f3f8
...@@ -30,7 +30,7 @@ class GATReduce(nn.Module): ...@@ -30,7 +30,7 @@ class GATReduce(nn.Module):
e = F.softmax(F.leaky_relu(a), dim=0) e = F.softmax(F.leaky_relu(a), dim=0)
if self.attn_drop != 0.0: if self.attn_drop != 0.0:
e = F.dropout(e, self.attn_drop) e = F.dropout(e, self.attn_drop)
return torch.sum(e * ft, dim=0) # shape (D,) return {'accum' : torch.sum(e * ft, dim=0)} # shape (D,)
class GATFinalize(nn.Module): class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual): def __init__(self, headid, indim, hiddendim, activation, residual):
...@@ -43,8 +43,8 @@ class GATFinalize(nn.Module): ...@@ -43,8 +43,8 @@ class GATFinalize(nn.Module):
if indim != hiddendim: if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim) self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node, accum): def forward(self, node):
ret = accum ret = node['accum']
if self.residual: if self.residual:
if self.residual_fc is not None: if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret ret = self.residual_fc(node['h']) + ret
......
...@@ -33,7 +33,7 @@ class GATReduce(nn.Module): ...@@ -33,7 +33,7 @@ class GATReduce(nn.Module):
e = F.softmax(F.leaky_relu(a), dim=1) e = F.softmax(F.leaky_relu(a), dim=1)
if self.attn_drop != 0.0: if self.attn_drop != 0.0:
e = F.dropout(e, self.attn_drop) e = F.dropout(e, self.attn_drop)
return torch.sum(e * ft, dim=1) # shape (B, D) return {'accum' : torch.sum(e * ft, dim=1)} # shape (B, D)
class GATFinalize(nn.Module): class GATFinalize(nn.Module):
def __init__(self, headid, indim, hiddendim, activation, residual): def __init__(self, headid, indim, hiddendim, activation, residual):
...@@ -46,8 +46,8 @@ class GATFinalize(nn.Module): ...@@ -46,8 +46,8 @@ class GATFinalize(nn.Module):
if indim != hiddendim: if indim != hiddendim:
self.residual_fc = nn.Linear(indim, hiddendim) self.residual_fc = nn.Linear(indim, hiddendim)
def forward(self, node, accum): def forward(self, node):
ret = accum ret = node['accum']
if self.residual: if self.residual:
if self.residual_fc is not None: if self.residual_fc is not None:
ret = self.residual_fc(node['h']) + ret ret = self.residual_fc(node['h']) + ret
......
...@@ -16,16 +16,16 @@ def gcn_msg(src, edge): ...@@ -16,16 +16,16 @@ def gcn_msg(src, edge):
return src['h'] return src['h']
def gcn_reduce(node, msgs): def gcn_reduce(node, msgs):
return sum(msgs) return {'h' : sum(msgs)}
class NodeUpdateModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__() super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats) self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation self.activation = activation
def forward(self, node, accum): def forward(self, node):
h = self.linear(accum) h = self.linear(node['h'])
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return {'h' : h} return {'h' : h}
...@@ -43,12 +43,12 @@ class GCN(nn.Module): ...@@ -43,12 +43,12 @@ class GCN(nn.Module):
self.g = DGLGraph(nx_graph) self.g = DGLGraph(nx_graph)
self.dropout = dropout self.dropout = dropout
# input layer # input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)]) self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers # hidden layers
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation)) self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer # output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features, train_nodes): def forward(self, features, train_nodes):
for n, feat in features.items(): for n, feat in features.items():
......
...@@ -21,14 +21,14 @@ def gcn_msg(src, edge): ...@@ -21,14 +21,14 @@ def gcn_msg(src, edge):
def gcn_reduce(node, msgs): def gcn_reduce(node, msgs):
return torch.sum(msgs, 1) return torch.sum(msgs, 1)
class NodeUpdateModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__() super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats) self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation self.activation = activation
def forward(self, node, accum): def forward(self, node):
h = self.linear(accum) h = self.linear(node)
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return h return h
...@@ -46,12 +46,12 @@ class GCN(nn.Module): ...@@ -46,12 +46,12 @@ class GCN(nn.Module):
self.g = g self.g = g
self.dropout = dropout self.dropout = dropout
# input layer # input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)]) self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers # hidden layers
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation)) self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer # output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr(features) self.g.set_n_repr(features)
......
...@@ -12,17 +12,18 @@ import torch ...@@ -12,17 +12,18 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
import dgl.function as fn
from dgl import DGLGraph from dgl import DGLGraph
from dgl.data import register_data_args, load_data from dgl.data import register_data_args, load_data
class NodeUpdateModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation=None): def __init__(self, in_feats, out_feats, activation=None):
super(NodeUpdateModule, self).__init__() super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats) self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation self.activation = activation
def forward(self, node, accum): def forward(self, node):
h = self.linear(accum) h = self.linear(node)
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return h return h
...@@ -40,12 +41,12 @@ class GCN(nn.Module): ...@@ -40,12 +41,12 @@ class GCN(nn.Module):
self.g = g self.g = g
self.dropout = dropout self.dropout = dropout
# input layer # input layer
self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)]) self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
# hidden layers # hidden layers
for i in range(n_layers - 1): for i in range(n_layers - 1):
self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation)) self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
# output layer # output layer
self.layers.append(NodeUpdateModule(n_hidden, n_classes)) self.layers.append(NodeApplyModule(n_hidden, n_classes))
def forward(self, features): def forward(self, features):
self.g.set_n_repr(features) self.g.set_n_repr(features)
...@@ -54,7 +55,7 @@ class GCN(nn.Module): ...@@ -54,7 +55,7 @@ class GCN(nn.Module):
if self.dropout: if self.dropout:
val = F.dropout(self.g.get_n_repr(), p=self.dropout) val = F.dropout(self.g.get_n_repr(), p=self.dropout)
self.g.set_n_repr(val) self.g.set_n_repr(val)
self.g.update_all('from_src', 'sum', layer, batchable=True) self.g.update_all(fn.copy_src(), fn.sum(), layer, batchable=True)
return self.g.pop_n_repr() return self.g.pop_n_repr()
def main(args): def main(args):
......
...@@ -21,10 +21,10 @@ import dgl.context as ctx ...@@ -21,10 +21,10 @@ import dgl.context as ctx
def tensor_topo_traverse(g, cuda, args): def tensor_topo_traverse(g, cuda, args):
n = g.number_of_nodes() n = g.number_of_nodes()
if cuda: if cuda:
adjmat = g.cached_graph.adjmat(ctx.gpu(args.gpu)) adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
mask = th.ones((n, 1)).cuda() mask = th.ones((n, 1)).cuda()
else: else:
adjmat = g.cached_graph.adjmat(ctx.cpu()) adjmat = g.cached_graph.adjmat().get(ctx.cpu())
mask = th.ones((n, 1)) mask = th.ones((n, 1))
degree = th.spmm(adjmat, mask) degree = th.spmm(adjmat, mask)
while th.sum(mask) != 0.: while th.sum(mask) != 0.:
...@@ -59,6 +59,9 @@ def main(args): ...@@ -59,6 +59,9 @@ def main(args):
args.dropout) args.dropout)
if cuda: if cuda:
model.cuda() model.cuda()
zero_initializer = lambda shape : th.zeros(shape).cuda()
else:
zero_initializer = th.zeros
print(model) print(model)
optimizer = optim.Adagrad(model.parameters(), optimizer = optim.Adagrad(model.parameters(),
lr=args.lr, lr=args.lr,
...@@ -68,21 +71,14 @@ def main(args): ...@@ -68,21 +71,14 @@ def main(args):
t_epoch = time.time() t_epoch = time.time()
for step, batch in enumerate(train_loader): for step, batch in enumerate(train_loader):
g = batch.graph g = batch.graph
n = g.number_of_nodes()
x = th.zeros((n, args.x_size))
h = th.zeros((n, args.h_size))
c = th.zeros((n, args.h_size))
if cuda: if cuda:
batch = _batch_to_cuda(batch) batch = _batch_to_cuda(batch)
x = x.cuda()
h = h.cuda()
c = c.cuda()
if step >= 3: if step >= 3:
t0 = time.time() t0 = time.time()
# traverse graph # traverse graph
giter = list(tensor_topo_traverse(g, False, args)) giter = list(tensor_topo_traverse(g, False, args))
logits = model(batch, x, h, c, iterator=giter, train=True) logits = model(batch, zero_initializer, iterator=giter, train=True)
logp = F.log_softmax(logits, 1) logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label) loss = F.nll_loss(logp, batch.label)
optimizer.zero_grad() optimizer.zero_grad()
......
...@@ -51,20 +51,14 @@ class ChildSumTreeLSTMCell(nn.Module): ...@@ -51,20 +51,14 @@ class ChildSumTreeLSTMCell(nn.Module):
# equation (7) second term # equation (7) second term
c_tild = th.sum(f * msgs['c'], 1) c_tild = th.sum(f * msgs['c'], 1)
return {'h_tild' : h_tild, 'c_tild' : c_tild} return {'h_tild' : h_tild, 'c_tild' : c_tild}
def update_func(self, node, accum): def apply_func(self, node):
# equation (3), (5), (6) # equation (3), (5), (6)
if accum is None: iou = self.W_iou(node['x']) + self.U_iou(node['h_tild'])
iou = self.W_iou(node['x'])
else:
iou = self.W_iou(node['x']) + self.U_iou(accum['h_tild'])
i, o, u = th.chunk(iou, 3, 1) i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u) i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (7) # equation (7)
if accum is None: c = i * u + node['c_tild']
c = i * u
else:
c = i * u + accum['c_tild']
# equation (8) # equation (8)
h = o * th.tanh(c) h = o * th.tanh(c)
return {'h' : h, 'c' : c} return {'h' : h, 'c' : c}
...@@ -79,6 +73,7 @@ class TreeLSTM(nn.Module): ...@@ -79,6 +73,7 @@ class TreeLSTM(nn.Module):
cell_type='childsum'): cell_type='childsum'):
super(TreeLSTM, self).__init__() super(TreeLSTM, self).__init__()
self.x_size = x_size self.x_size = x_size
self.h_size = h_size
# TODO(minjie): pre-trained embedding like GLoVe # TODO(minjie): pre-trained embedding like GLoVe
self.embedding = nn.Embedding(num_vocabs, x_size) self.embedding = nn.Embedding(num_vocabs, x_size)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
...@@ -88,20 +83,20 @@ class TreeLSTM(nn.Module): ...@@ -88,20 +83,20 @@ class TreeLSTM(nn.Module):
else: else:
raise RuntimeError('Unknown cell type:', cell_type) raise RuntimeError('Unknown cell type:', cell_type)
def forward(self, batch, x, h, c, iterator=None, train=True): def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
"""Compute tree-lstm prediction given a batch. """Compute tree-lstm prediction given a batch.
Parameters Parameters
---------- ----------
batch : dgl.data.SSTBatch batch : dgl.data.SSTBatch
The data batch. The data batch.
x : Tensor zero_initializer : callable
Initial node input. Function to return zero value tensor.
h : Tensor h : Tensor, optional
Initial hidden state. Initial hidden state.
c : Tensor c : Tensor, optional
Initial cell state. Initial cell state.
iterator : graph iterator iterator : graph iterator, optional
External iterator on graph. External iterator on graph.
Returns Returns
...@@ -110,21 +105,29 @@ class TreeLSTM(nn.Module): ...@@ -110,21 +105,29 @@ class TreeLSTM(nn.Module):
The prediction of each node. The prediction of each node.
""" """
g = batch.graph g = batch.graph
n = g.number_of_nodes()
g.register_message_func(self.cell.message_func, batchable=True) g.register_message_func(self.cell.message_func, batchable=True)
g.register_reduce_func(self.cell.reduce_func, batchable=True) g.register_reduce_func(self.cell.reduce_func, batchable=True)
g.register_update_func(self.cell.update_func, batchable=True) g.register_apply_node_func(self.cell.apply_func, batchable=True)
# feed embedding # feed embedding
embeds = self.embedding(batch.wordid) embeds = self.embedding(batch.wordid)
x = zero_initializer((n, self.x_size))
x = x.index_copy(0, batch.nid_with_word, embeds) x = x.index_copy(0, batch.nid_with_word, embeds)
g.set_n_repr({'x' : x, 'h' : h, 'c' : c}) if h is None:
h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size))
if c is None:
c = zero_initializer((n, self.h_size))
c_tild = zero_initializer((n, self.h_size))
g.set_n_repr({'x' : x, 'h' : h, 'c' : c, 'h_tild' : h_tild, 'c_tild' : c_tild})
# TODO(minjie): potential bottleneck # TODO(minjie): potential bottleneck
if iterator is None: if iterator is None:
for frontier in topological_traverse(g): for frontier in topological_traverse(g):
#print('frontier', frontier) #print('frontier', frontier)
g.update_to(frontier) g.pull(frontier)
else: else:
for frontier in iterator: for frontier in iterator:
g.update_to(frontier) g.pull(frontier)
# compute logits # compute logits
h = g.pop_n_repr('h') h = g.pop_n_repr('h')
h = self.dropout(h) h = self.dropout(h)
......
...@@ -63,6 +63,7 @@ zeros = th.zeros ...@@ -63,6 +63,7 @@ zeros = th.zeros
spmm = th.spmm spmm = th.spmm
sort = th.sort sort = th.sort
arange = th.arange arange = th.arange
mul = th.mul
def to_context(x, ctx): def to_context(x, ctx):
if ctx is None: if ctx is None:
......
...@@ -2,5 +2,9 @@ ...@@ -2,5 +2,9 @@
# A special argument for selecting all nodes/edges. # A special argument for selecting all nodes/edges.
ALL = "__ALL__" ALL = "__ALL__"
def is_all(arg): def is_all(arg):
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
"""Built-in functors."""
from __future__ import absolute_import
import dgl.backend as F
def message_from_src(src, edge):
    """Built-in message function: relay the source node's representation.

    The ``edge`` representation is accepted for signature compatibility
    but is not used.
    """
    return src
def reduce_sum(node, msgs):
    """Built-in reducer that sums the incoming messages.

    In non-batched mode ``msgs`` is a Python list (possibly of dicts keyed
    by field name, which are summed field-wise). Otherwise ``msgs`` is a
    batched tensor and the sum is taken over dimension 1.
    """
    if not isinstance(msgs, list):
        # Batched mode: reduce over the message dimension.
        return F.sum(msgs, 1)
    first = msgs[0]
    if isinstance(first, dict):
        # Sum each field across all message dicts.
        return {key: sum(msg[key] for msg in msgs) for key in first}
    return sum(msgs)
def reduce_max(node, msgs):
    """Built-in reducer that takes the max over the incoming messages.

    Non-batched mode (``msgs`` is a list) uses Python's ``max``; batched
    mode reduces the message tensor over dimension 1.
    """
    if not isinstance(msgs, list):
        return F.max(msgs, 1)
    return max(msgs)
...@@ -118,8 +118,8 @@ class CachedGraph: ...@@ -118,8 +118,8 @@ class CachedGraph:
dst = utils.toindex(dst) dst = utils.toindex(dst)
return src, dst return src, dst
@utils.ctx_cached_member @utils.cached_member
def adjmat(self, ctx): def adjmat(self):
"""Return a sparse adjacency matrix. """Return a sparse adjacency matrix.
The row dimension represents the dst nodes; the column dimension The row dimension represents the dst nodes; the column dimension
...@@ -134,8 +134,7 @@ class CachedGraph: ...@@ -134,8 +134,7 @@ class CachedGraph:
n = self._graph.vcount() n = self._graph.vcount()
dat = F.ones((len(elist),)) dat = F.ones((len(elist),))
mat = F.sparse_tensor(idx, dat, [n, n]) mat = F.sparse_tensor(idx, dat, [n, n])
mat = F.to_context(mat, ctx) return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return mat
def freeze(self): def freeze(self):
self._freeze = True self._freeze = True
......
...@@ -22,7 +22,6 @@ _urls = { ...@@ -22,7 +22,6 @@ _urls = {
class CitationGraphDataset(object): class CitationGraphDataset(object):
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.mode = mode
self.dir = get_download_dir() self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name) self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_urls[name], path=self.zip_file_path) download(_urls[name], path=self.zip_file_path)
......
...@@ -163,6 +163,10 @@ class FrameRef(MutableMapping): ...@@ -163,6 +163,10 @@ class FrameRef(MutableMapping):
def update_rows(self, query, other): def update_rows(self, query, other):
rowids = self._getrowid(query) rowids = self._getrowid(query)
for key, col in other.items(): for key, col in other.items():
if key not in self:
# add new column
tmpref = FrameRef(self._frame, rowids)
tmpref.add_column(key, col)
idx = rowids.totensor(F.get_context(self._frame[key])) idx = rowids.totensor(F.get_context(self._frame[key]))
self._frame[key] = F.scatter_row(self._frame[key], idx, col) self._frame[key] = F.scatter_row(self._frame[key], idx, col)
......
from .message import *
from .reducer import *
"""Built-in message function."""
from __future__ import absolute_import
import operator
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object):
    """Abstract base class for built-in message functions.

    Subclasses compute a message from a source node representation and an
    edge representation, and expose a symbolic name.
    """

    def __call__(self, src, edge):
        """Compute the message from ``src`` and ``edge`` representations."""
        raise NotImplementedError

    def name(self):
        """Return this message function's symbolic name."""
        raise NotImplementedError
class BundledMessageFunction(MessageFunction):
    """Apply several message functions and merge their results into one dict.

    Each wrapped function should return a dict (i.e. be constructed with an
    ``out_field``); results are merged with ``dict.update``, so later
    functions overwrite earlier ones on key collision.
    """

    def __init__(self, fn_list):
        # Message functions (or compatible callables) applied in order.
        self.fn_list = fn_list

    def __call__(self, src, edge):
        ret = None
        for fn in self.fn_list:
            msg = fn(src, edge)
            if ret is None:
                ret = msg
            else:
                try:
                    ret.update(msg)
                # BUGFIX: the original `except e:` referenced an undefined
                # name, so a merge failure raised NameError instead of the
                # intended RuntimeError. Catch the errors dict.update can
                # actually raise when a result is not a dict.
                except (AttributeError, TypeError, ValueError):
                    raise RuntimeError("Failed to merge results of two builtin"
                                       " message functions. Please specify out_field"
                                       " for the builtin message function.")
        return ret

    def name(self):
        return "bundled"
class SrcMulEdgeMessageFunction(MessageFunction):
    """Message function computing ``mul_op(src, edge)`` on selected fields.

    If ``src_field``/``edge_field`` is None the whole representation is
    used; if ``out_field`` is None the raw result is returned instead of a
    single-key dict.
    """

    def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
        self.mul_op = mul_op
        self.src_field = src_field
        self.edge_field = edge_field
        self.out_field = out_field

    def __call__(self, src, edge):
        lhs = src if self.src_field is None else src[self.src_field]
        rhs = edge if self.edge_field is None else edge[self.edge_field]
        msg = self.mul_op(lhs, rhs)
        if self.out_field is None:
            return msg
        return {self.out_field: msg}

    def name(self):
        return "src_mul_edge"
class CopySrcMessageFunction(MessageFunction):
    """Message function forwarding (a field of) the source representation.

    ``edge`` is ignored; the message is drawn from the source node only.
    """

    def __init__(self, src_field=None, out_field=None):
        self.src_field = src_field
        self.out_field = out_field

    def __call__(self, src, edge):
        payload = src if self.src_field is None else src[self.src_field]
        if self.out_field is None:
            return payload
        return {self.out_field: payload}

    def name(self):
        return "copy_src"
class CopyEdgeMessageFunction(MessageFunction):
    """Message function forwarding (a field of) the edge representation.

    ``src`` is ignored; the message is drawn from the edge only.
    """

    def __init__(self, edge_field=None, out_field=None):
        self.edge_field = edge_field
        self.out_field = out_field

    def __call__(self, src, edge):
        payload = edge if self.edge_field is None else edge[self.edge_field]
        if self.out_field is None:
            return payload
        return {self.out_field: payload}

    def name(self):
        return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None):
    """Create a builtin message function multiplying a source field by an edge field.

    Parameters
    ----------
    src : str, optional
        Source node field to read; the whole representation if None.
    edge : str, optional
        Edge field to read; the whole representation if None.
    out : str, optional
        Output field name; if None the raw product is returned.

    Returns
    -------
    SrcMulEdgeMessageFunction
        Message function computing ``src * edge`` via ``operator.mul``.
    """
    return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
def copy_src(src=None, out=None):
    """Create a builtin message function that copies the source representation.

    Parameters
    ----------
    src : str, optional
        Source node field to copy; the whole representation if None.
    out : str, optional
        Output field name; if None the copied value is returned directly.

    Returns
    -------
    CopySrcMessageFunction
    """
    return CopySrcMessageFunction(src, out)
def copy_edge(edge=None, out=None):
    """Create a builtin message function that copies the edge representation.

    Parameters
    ----------
    edge : str, optional
        Edge field to copy; the whole representation if None.
    out : str, optional
        Output field name; if None the copied value is returned directly.

    Returns
    -------
    CopyEdgeMessageFunction
    """
    return CopyEdgeMessageFunction(edge, out)
"""Built-in reducer function."""
from __future__ import absolute_import
import dgl.backend as F
__all__ = ["ReduceFunction", "sum", "max"]
class ReduceFunction(object):
    """Abstract base class for built-in reduce functions.

    Subclasses combine a node's incoming messages into a new representation
    and expose a symbolic name.
    """

    def __call__(self, node, msgs):
        """Reduce ``msgs`` for the given node representation."""
        raise NotImplementedError

    def name(self):
        """Return this reduce function's symbolic name."""
        raise NotImplementedError
class BundledReduceFunction(ReduceFunction):
    """Apply several reduce functions and merge their results into one dict.

    Each wrapped function should return a dict (i.e. be constructed with an
    ``out_field``); results are merged with ``dict.update``, so later
    functions overwrite earlier ones on key collision.
    """

    def __init__(self, fn_list):
        # Reduce functions (or compatible callables) applied in order.
        self.fn_list = fn_list

    def __call__(self, node, msgs):
        ret = None
        for fn in self.fn_list:
            rpr = fn(node, msgs)
            if ret is None:
                ret = rpr
            else:
                try:
                    ret.update(rpr)
                # BUGFIX: the original `except e:` referenced an undefined
                # name, so a merge failure raised NameError instead of the
                # intended RuntimeError. Catch the errors dict.update can
                # actually raise when a result is not a dict.
                except (AttributeError, TypeError, ValueError):
                    raise RuntimeError("Failed to merge results of two builtin"
                                       " reduce functions. Please specify out_field"
                                       " for the builtin reduce function.")
        return ret

    def name(self):
        return "bundled"
class SumReducerFunction(ReduceFunction):
    """Reducer that aggregates messages in batched or non-batched form.

    Parameters
    ----------
    batch_sum_op : callable
        Reduction over dimension 1 of a batched message tensor.
    nonbatch_sum_op : callable
        Reduction over a Python list of messages.
    msg_field : str, optional
        Field to read from each message; the whole message if None.
    out_field : str, optional
        Output field name; if None the raw reduction is returned.
    """

    def __init__(self, batch_sum_op, nonbatch_sum_op, msg_field=None, out_field=None):
        self.batch_sum_op = batch_sum_op
        self.nonbatch_sum_op = nonbatch_sum_op
        self.msg_field = msg_field
        self.out_field = out_field

    def __call__(self, node, msgs):
        if isinstance(msgs, list):
            # Non-batched: msgs is a plain list (of values or dicts).
            if self.msg_field is None:
                reduced = self.nonbatch_sum_op(msgs)
            else:
                reduced = self.nonbatch_sum_op([m[self.msg_field] for m in msgs])
        else:
            # Batched: msgs is a tensor (or dict of tensors) with the
            # message axis at dimension 1.
            data = msgs if self.msg_field is None else msgs[self.msg_field]
            reduced = self.batch_sum_op(data, 1)
        if self.out_field is None:
            return reduced
        return {self.out_field: reduced}

    def name(self):
        return "sum"
# Keep a handle on Python's builtin before the API name shadows it below.
_python_sum = sum

def sum(msgs=None, out=None):
    """Create a builtin reducer that sums incoming messages.

    Parameters
    ----------
    msgs : str, optional
        Message field to reduce; the whole message if None.
    out : str, optional
        Output field name; if None the raw sum is returned.

    Returns
    -------
    SumReducerFunction
    """
    return SumReducerFunction(F.sum, _python_sum, msgs, out)
# Keep a handle on Python's builtin before the API name shadows it below.
_python_max = max

def max(msgs=None, out=None):
    """Create a builtin reducer that takes the max of incoming messages.

    NOTE(review): this reuses SumReducerFunction, so the returned reducer's
    ``name()`` reports "sum" rather than "max" — confirm no consumer
    dispatches on the name.

    Parameters
    ----------
    msgs : str, optional
        Message field to reduce; the whole message if None.
    out : str, optional
        Output field name; if None the raw max is returned.

    Returns
    -------
    SumReducerFunction
    """
    return SumReducerFunction(F.max, _python_max, msgs, out)
...@@ -6,10 +6,9 @@ import networkx as nx ...@@ -6,10 +6,9 @@ import networkx as nx
from networkx.classes.digraph import DiGraph from networkx.classes.digraph import DiGraph
import dgl import dgl
from dgl.base import ALL, is_all from dgl.base import ALL, is_all, __MSG__, __REPR__
import dgl.backend as F import dgl.backend as F
from dgl.backend import Tensor from dgl.backend import Tensor
import dgl.builtin as builtin
from dgl.cached_graph import CachedGraph, create_cached_graph from dgl.cached_graph import CachedGraph, create_cached_graph
import dgl.context as context import dgl.context as context
from dgl.frame import FrameRef, merge_frames from dgl.frame import FrameRef, merge_frames
...@@ -17,9 +16,6 @@ from dgl.nx_adapt import nx_init ...@@ -17,9 +16,6 @@ from dgl.nx_adapt import nx_init
import dgl.scheduler as scheduler import dgl.scheduler as scheduler
import dgl.utils as utils import dgl.utils as utils
__MSG__ = "__MSG__"
__REPR__ = "__REPR__"
class DGLGraph(DiGraph): class DGLGraph(DiGraph):
"""Base graph class specialized for neural networks on graphs. """Base graph class specialized for neural networks on graphs.
...@@ -58,10 +54,11 @@ class DGLGraph(DiGraph): ...@@ -58,10 +54,11 @@ class DGLGraph(DiGraph):
# other class members # other class members
self._msg_graph = None self._msg_graph = None
self._msg_frame = FrameRef() self._msg_frame = FrameRef()
self._message_func = None self._message_func = (None, None)
self._reduce_func = None self._reduce_func = (None, None)
self._update_func = None self._edge_func = (None, None)
self._edge_func = None self._apply_node_func = (None, None)
self._apply_edge_func = (None, None)
def node_attr_schemes(self): def node_attr_schemes(self):
return self._node_frame.schemes return self._node_frame.schemes
...@@ -281,33 +278,33 @@ class DGLGraph(DiGraph): ...@@ -281,33 +278,33 @@ class DGLGraph(DiGraph):
else: else:
return self._edge_frame.select_rows(eid) return self._edge_frame.select_rows(eid)
def register_message_func(self, def register_edge_func(self,
message_func, edge_func,
batchable=False): batchable=False):
"""Register global message function. """Register global edge update function.
Parameters Parameters
---------- ----------
message_func : callable edge_func : callable
Message function on the edge. Message function on the edge.
batchable : bool batchable : bool
Whether the provided message function allows batch computing. Whether the provided message function allows batch computing.
""" """
self._message_func = (message_func, batchable) self._edge_func = (edge_func, batchable)
def register_edge_func(self, def register_message_func(self,
edge_func, message_func,
batchable=False): batchable=False):
"""Register global edge update function. """Register global message function.
Parameters Parameters
---------- ----------
edge_func : callable message_func : callable
Message function on the edge. Message function on the edge.
batchable : bool batchable : bool
Whether the provided message function allows batch computing. Whether the provided message function allows batch computing.
""" """
self._edge_func = (edge_func, batchable) self._message_func = (message_func, batchable)
def register_reduce_func(self, def register_reduce_func(self,
reduce_func, reduce_func,
...@@ -323,21 +320,94 @@ class DGLGraph(DiGraph): ...@@ -323,21 +320,94 @@ class DGLGraph(DiGraph):
""" """
self._reduce_func = (reduce_func, batchable) self._reduce_func = (reduce_func, batchable)
def register_update_func(self, def register_apply_node_func(self,
update_func, apply_node_func,
batchable=False): batchable=False):
"""Register global node update function. """Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
batchable : bool
Whether the provided function allows batch computing.
"""
self._apply_node_func = (apply_node_func, batchable)
def register_apply_edge_func(self,
apply_edge_func,
batchable=False):
"""Register global edge apply function.
Parameters Parameters
---------- ----------
update_func : callable apply_edge_func : callable
Update function on the node. Apply function on the edge.
batchable : bool batchable : bool
Whether the provided update function allows batch computing. Whether the provided function allows batch computing.
""" """
self._update_func = (update_func, batchable) self._apply_edge_func = (apply_edge_func, batchable)
def sendto(self, u, v, message_func=None, batchable=False): def apply_nodes(self, v, apply_node_func="default", batchable=False):
"""Apply the function on node representations.
Parameters
----------
v : int, iterable of int, tensor
The node id(s).
apply_node_func : callable
The apply node function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_node_func == "default":
apply_node_func, batchable = self._apply_node_func
if not apply_node_func:
# Skip none function call.
return
if batchable:
new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v)
else:
if is_all(v):
v = self.nodes()
v = utils.toindex(v)
for vv in utils.node_iter(v):
ret = apply_node_func(_get_repr(self.nodes[vv]))
_set_repr(self.nodes[vv], ret)
def apply_edges(self, u, v, apply_edge_func="default", batchable=False):
"""Apply the function on edge representations.
Parameters
----------
u : int, iterable of int, tensor
The src node id(s).
v : int, iterable of int, tensor
The dst node id(s).
apply_edge_func : callable
The apply edge function.
batchable : bool
Whether the provided function allows batch computing.
"""
if apply_edge_func == "default":
apply_edge_func, batchable = self._apply_edge_func
if not apply_edge_func:
# Skip none function call.
return
if batchable:
new_repr = apply_edge_func(self.get_e_repr(u, v))
self.set_e_repr(new_repr, u, v)
else:
if is_all(u) == is_all(v):
u, v = zip(*self.edges)
u = utils.toindex(u)
v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v):
ret = apply_edge_func(_get_repr(self.edges[uu, vv]))
_set_repr(self.edges[uu, vv], ret)
def send(self, u, v, message_func="default", batchable=False):
"""Trigger the message function on edge u->v """Trigger the message function on edge u->v
The message function should be compatible with following signature: The message function should be compatible with following signature:
...@@ -356,33 +426,31 @@ class DGLGraph(DiGraph): ...@@ -356,33 +426,31 @@ class DGLGraph(DiGraph):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
message_func : str or callable message_func : callable
The message function. The message function.
batchable : bool batchable : bool
Whether the function allows batched computation. Whether the function allows batched computation.
""" """
if message_func is None: if message_func == "default":
message_func, batchable = self._message_func message_func, batchable = self._message_func
assert message_func is not None assert message_func is not None
if batchable: if batchable:
self._batch_sendto(u, v, message_func) self._batch_send(u, v, message_func)
else: else:
self._nonbatch_sendto(u, v, message_func) self._nonbatch_send(u, v, message_func)
def _nonbatch_sendto(self, u, v, message_func): def _nonbatch_send(self, u, v, message_func):
f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v = self.cached_graph.edges()
else: else:
u = utils.toindex(u) u = utils.toindex(u)
v = utils.toindex(v) v = utils.toindex(v)
for uu, vv in utils.edge_iter(u, v): for uu, vv in utils.edge_iter(u, v):
ret = f_msg(_get_repr(self.nodes[uu]), ret = message_func(_get_repr(self.nodes[uu]),
_get_repr(self.edges[uu, vv])) _get_repr(self.edges[uu, vv]))
self.edges[uu, vv][__MSG__] = ret self.edges[uu, vv][__MSG__] = ret
def _batch_sendto(self, u, v, message_func): def _batch_send(self, u, v, message_func):
f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v = self.cached_graph.edges()
self.msg_graph.add_edges(u, v) self.msg_graph.add_edges(u, v)
...@@ -405,7 +473,7 @@ class DGLGraph(DiGraph): ...@@ -405,7 +473,7 @@ class DGLGraph(DiGraph):
else: else:
self._msg_frame.append({__MSG__ : msgs}) self._msg_frame.append({__MSG__ : msgs})
def update_edge(self, u, v, edge_func=None, batchable=False): def update_edge(self, u, v, edge_func="default", batchable=False):
"""Update representation on edge u->v """Update representation on edge u->v
The edge function should be compatible with following signature: The edge function should be compatible with following signature:
...@@ -422,12 +490,12 @@ class DGLGraph(DiGraph): ...@@ -422,12 +490,12 @@ class DGLGraph(DiGraph):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
edge_func : str or callable edge_func : callable
The update function. The update function.
batchable : bool batchable : bool
Whether the function allows batched computation. Whether the function allows batched computation.
""" """
if edge_func is None: if edge_func == "default":
edge_func, batchable = self._edge_func edge_func, batchable = self._edge_func
assert edge_func is not None assert edge_func is not None
if batchable: if batchable:
...@@ -470,142 +538,93 @@ class DGLGraph(DiGraph): ...@@ -470,142 +538,93 @@ class DGLGraph(DiGraph):
def recv(self, def recv(self,
u, u,
reduce_func=None, reduce_func="default",
update_func=None, apply_node_func="default",
batchable=False): batchable=False):
"""Receive in-coming messages and update representation on node u. """Receive and reduce in-coming messages and update representation on node u.
It computes the new node state using the messages sent from the predecessors It computes the new node state using the messages sent from the predecessors
of node u. If no message is found from the predecessors, reduce function of node u. If no message is found from the predecessors, reduce function
will be skipped and a None type will be provided as the reduced messages for will be skipped.
the update function.
The reduce function should be compatible with following signature: The reduce function should be compatible with following signature:
(node_reprs, batched_messages) -> reduced_messages (node_reprs, batched_messages) -> node_reprs
It computes the reduced edge representations using the representations It computes the new node representations using the representations
of the in-coming edges (the same concept as messages). of the in-coming edges (the same concept as messages).
The reduce function can be any of the pre-defined functions ('sum', The reduce function can also be pre-defined functions.
'max'). If built-in function is used, computation will be performed
efficiently (using generic-SPMV kernels).
The update function should be compatible with following signature: An optinoal apply_node function could be specified and should follow following
signature:
(node_reprs, reduced_messages) -> node_reprs node_reprs -> node_reprs
It computes the new node representations using the representations All node_reprs and edge_reprs support tensor and dictionary types.
of the in-coming edges (the same concept as messages) and the node
itself. All node_reprs and edge_reprs are dictionaries.
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
The node to be updated. The node to be updated.
reduce_func : str or callable reduce_func : callable
The reduce function. The reduce function.
update_func : str or callable apply_node_func : callable, optional
The update function. The update function.
batchable : bool batchable : bool, optional
Whether the reduce and update function allows batched computation. Whether the reduce and update function allows batched computation.
""" """
if reduce_func is None: if reduce_func == "default":
reduce_func, batchable = self._reduce_func reduce_func, batchable = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert reduce_func is not None assert reduce_func is not None
assert update_func is not None
if batchable: if batchable:
self._batch_recv(u, reduce_func, update_func) self._batch_recv(u, reduce_func)
else: else:
self._nonbatch_recv(u, reduce_func, update_func) self._nonbatch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func, batchable)
def _nonbatch_recv(self, u, reduce_func, update_func): def _nonbatch_recv(self, u, reduce_func):
f_reduce = _get_reduce_func(reduce_func)
if is_all(u): if is_all(u):
u = list(range(0, self.number_of_nodes())) u = list(range(0, self.number_of_nodes()))
else: else:
u = utils.toindex(u) u = utils.toindex(u)
for i, uu in enumerate(utils.node_iter(u)): for i, uu in enumerate(utils.node_iter(u)):
# reduce phase # reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__) msgs_batch = [self.edges[vv, uu].pop(__MSG__)
for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]] for vv in self.pred[uu] if __MSG__ in self.edges[vv, uu]]
if len(msgs_batch) == 0: if len(msgs_batch) != 0:
msgs_reduced = None new_repr = reduce_func(_get_repr(self.nodes[uu]), msgs_batch)
else: _set_repr(self.nodes[uu], new_repr)
msgs_reduced = f_reduce(_get_repr(self.nodes[uu]), msgs_batch)
# update phase
ret = update_func(_get_repr(self.nodes[uu]), msgs_reduced)
_set_repr(self.nodes[uu], ret)
def _batch_recv(self, v, reduce_func, update_func):
if len(v) == 0:
# no vertex to be triggered.
return
null_v, reordered_v, all_reduced_msgs = self._batch_reduce(v, reduce_func)
if all_reduced_msgs is None:
# no message; only do recv.
if is_all(v):
self.set_n_repr(update_func(self.get_n_repr(), None))
else:
self.set_n_repr(update_func(self.get_n_repr(v), None), v)
else:
# Compute new node repr for nodes with no in-coming messages.
if len(null_v) == 0:
new_null_ns = None
else:
new_null_ns = update_func(self.get_n_repr(null_v), None)
# Read the node states in the degree-bucketing order.
if len(reordered_v) == 0:
new_reordered_ns = None
else:
new_reordered_ns = update_func(self.get_n_repr(reordered_v), all_reduced_msgs)
v_tensor = utils.pack2(null_v.totensor(), reordered_v.totensor())
new_ns = utils.pack2(new_null_ns, new_reordered_ns)
if is_all(v): def _batch_recv(self, v, reduce_func):
# First do reorder and then replace the whole column.
_, indices = F.sort(v_tensor)
indices = utils.toindex(indices)
# TODO(minjie): following code should be included in Frame somehow.
if utils.is_dict_like(new_ns):
for key, val in new_ns.items():
idx = indices.totensor(F.get_context(val))
self._node_frame[key] = F.gather_row(val, idx)
else:
idx = indices.totensor(F.get_context(new_ns))
self._node_frame[__REPR__] = F.gather_row(new_ns, idx)
else:
# Use setter to do reorder.
self.set_n_repr(new_ns, v_tensor)
def _batch_reduce(self, v, reduce_func):
if self._msg_frame.num_rows == 0: if self._msg_frame.num_rows == 0:
# no message has ever been sent # no message has ever been sent
return None, None, None return
if is_all(v): v_is_all = is_all(v)
if v_is_all:
v = list(range(self.number_of_nodes())) v = list(range(self.number_of_nodes()))
if len(v) == 0:
# sanity checks # no vertex to be triggered.
return
v = utils.toindex(v) v = utils.toindex(v)
f_reduce = _get_reduce_func(reduce_func)
# degree bucketing # degree bucketing
degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v) degrees, v_buckets = scheduler.degree_bucketing(self.msg_graph, v)
null_v_bucket = None if degrees == [0]:
non_null_v_buckets = [] # no message has been sent to the specified node
reduced_msgs = [] return
for deg, v_bkt in zip(degrees, v_buckets):
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
reordered_v = []
new_reprs = []
has_zero_degree = False
for deg, v_bkt in zip(degrees, v_buckets):
if deg == 0: if deg == 0:
assert null_v_bucket is None # no need to trigger reduce func for zero-degree nodes
null_v_bucket = v_bkt has_zero_degree = True
continue continue
bkt_len = len(v_bkt)
dst_reprs = self.get_n_repr(v_bkt)
uu, vv, _ = self.msg_graph.in_edges(v_bkt) uu, vv, _ = self.msg_graph.in_edges(v_bkt)
in_msg_ids = self.msg_graph.get_edge_id(uu, vv) in_msg_ids = self.msg_graph.get_edge_id(uu, vv)
in_msgs = self._msg_frame.select_rows(in_msg_ids) in_msgs = self._msg_frame.select_rows(in_msg_ids)
...@@ -619,40 +638,37 @@ class DGLGraph(DiGraph): ...@@ -619,40 +638,37 @@ class DGLGraph(DiGraph):
else: else:
reshaped_in_msgs = utils.LazyDict( reshaped_in_msgs = utils.LazyDict(
lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes) lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)
non_null_v_buckets.append(v_bkt) reordered_v.append(v_bkt.totensor())
reduced_msgs.append(f_reduce(dst_reprs, reshaped_in_msgs)) new_reprs.append(reduce_func(dst_reprs, reshaped_in_msgs))
if len(reduced_msgs) == 0:
# no message has been sent to the specified node
return None, None, None
# TODO: clear partial messages # TODO: clear partial messages
self.clear_messages() self.clear_messages()
# Read the node states in the degree-bucketing order. # Pack all reducer results together
null_v = utils.toindex(null_v_bucket or []) reordered_v = F.pack(reordered_v)
reordered_v = utils.toindex( if utils.is_dict_like(new_reprs[0]):
F.pack([v_bkt.totensor() for v_bkt in non_null_v_buckets]) keys = new_reprs[0].keys()
if len(non_null_v_buckets) > 0 else [] new_reprs = {key : F.pack([repr[key] for repr in new_reprs])
) for key in keys}
else:
# Pack all reduced msgs together new_reprs = {__REPR__ : F.pack(new_reprs)}
if utils.is_dict_like(reduced_msgs[0]):
keys = reduced_msgs[0].keys() if v_is_all and not has_zero_degree:
all_reduced_msgs = { # First do reorder and then replace the whole column.
key : F.pack([msg[key] for msg in reduced_msgs]) _, indices = F.sort(reordered_v)
for key in keys} indices = utils.toindex(indices)
else: new_reprs = utils.reorder(new_reprs, indices)
all_reduced_msgs = F.pack(reduced_msgs) self.set_n_repr(new_reprs)
else:
return null_v, reordered_v, all_reduced_msgs # Use setter to do reorder.
self.set_n_repr(new_reprs, reordered_v)
def update_by_edge(self,
u, v, def send_and_recv(self,
message_func=None, u, v,
reduce_func=None, message_func="default",
update_func=None, reduce_func="default",
batchable=False): apply_node_func="default",
batchable=False):
"""Trigger the message function on u->v and update v. """Trigger the message function on u->v and update v.
Parameters Parameters
...@@ -661,238 +677,146 @@ class DGLGraph(DiGraph): ...@@ -661,238 +677,146 @@ class DGLGraph(DiGraph):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
message_func : str or callable message_func : callable
The message function. The message function.
reduce_func : str or callable reduce_func : callable
The reduce function. The reduce function.
update_func : str or callable apply_node_func : callable, optional
The update function. The update function.
batchable : bool batchable : bool
Whether the reduce and update function allows batched computation. Whether the reduce and update function allows batched computation.
""" """
if message_func is None: u = utils.toindex(u)
v = utils.toindex(v)
if len(u) == 0:
# no edges to be triggered
assert len(v) == 0
return
unique_v = utils.toindex(F.unique(v.totensor()))
# TODO(minjie): better way to figure out `batchable` flag
if message_func == "default":
message_func, batchable = self._message_func message_func, batchable = self._message_func
if reduce_func is None: if reduce_func == "default":
reduce_func, batchable = self._reduce_func reduce_func, _ = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
assert update_func is not None
if batchable: if batchable:
self._batch_update_by_edge( executor = scheduler.get_executor(
u, v, message_func, reduce_func, update_func) 'send_and_recv', self, src=u, dst=v,
else: message_func=message_func, reduce_func=reduce_func)
self._nonbatch_update_by_edge(
u, v, message_func, reduce_func, update_func)
def _nonbatch_update_by_edge(
self,
u, v,
message_func,
reduce_func,
update_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
else: else:
u = utils.toindex(u) executor = None
v = utils.toindex(v)
self._nonbatch_sendto(u, v, message_func) if executor:
dst = set() executor.run()
for uu, vv in utils.edge_iter(u, v):
dst.add(vv)
self._nonbatch_recv(list(dst), reduce_func, update_func)
def _batch_update_by_edge(
self,
u, v,
message_func,
reduce_func,
update_func):
if len(u) == 0:
# no message
assert len(v) == 0
elif is_all(u) and is_all(v):
self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v
u = utils.toindex(u)
v = utils.toindex(v)
# TODO(minjie): broadcasting is optional for many-one input.
u, v = utils.edge_broadcasting(u, v)
# relabel destination nodes.
new2old, old2new = utils.build_relabel_map(v)
u = u.totensor()
v = v.totensor()
# TODO(minjie): should not directly use []
new_v = old2new[v]
# create adj mat
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(ctx_adjmat.get(F.get_context(col)), col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr(new2old)
new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old)
else: else:
u = utils.toindex(u) self.send(u, v, message_func, batchable=batchable)
v = utils.toindex(v) self.recv(unique_v, reduce_func, None, batchable=batchable)
self._batch_sendto(u, v, message_func) self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
unique_v = F.unique(v.totensor())
self._batch_recv(unique_v, reduce_func, update_func) def pull(self,
v,
def update_to(self, message_func="default",
v, reduce_func="default",
message_func=None, apply_node_func="default",
reduce_func=None, batchable=False):
update_func=None,
batchable=False):
"""Pull messages from the node's predecessors and then update it. """Pull messages from the node's predecessors and then update it.
Parameters Parameters
---------- ----------
v : node, container or tensor v : node, container or tensor
The node to be updated. The node to be updated.
message_func : str or callable message_func : callable
The message function. The message function.
reduce_func : str or callable reduce_func : callable
The reduce function. The reduce function.
update_func : str or callable apply_node_func : callable, optional
The update function. The update function.
batchable : bool batchable : bool
Whether the reduce and update function allows batched computation. Whether the reduce and update function allows batched computation.
""" """
if message_func is None: v = utils.toindex(v)
message_func, batchable = self._message_func if len(v) == 0:
if reduce_func is None: return
reduce_func, batchable = self._reduce_func uu, vv, _ = self.cached_graph.in_edges(v)
if update_func is None: self.send_and_recv(uu, vv, message_func, reduce_func,
update_func, batchable = self._update_func apply_node_func=None, batchable=batchable)
assert message_func is not None unique_v = F.unique(v.totensor())
assert reduce_func is not None self.apply_nodes(unique_v, apply_node_func, batchable=batchable)
assert update_func is not None
if batchable: def push(self,
v = utils.toindex(v) u,
uu, vv, orphan = self.cached_graph.in_edges(v) message_func="default",
self._batch_update_by_edge(uu, vv, message_func, reduce_func="default",
reduce_func, update_func) apply_node_func="default",
# trigger update function for nodes that have no incoming messages. batchable=False):
self._batch_recv(orphan, reduce_func, update_func)
else:
v = utils.toindex(v)
for vv in utils.node_iter(v):
assert vv in self.nodes
uu = list(self.pred[vv])
if len(uu) > 0:
self._nonbatch_sendto(uu, vv, message_func)
self._nonbatch_recv(vv, reduce_func, update_func)
def update_from(self,
u,
message_func=None,
reduce_func=None,
update_func=None,
batchable=False):
"""Send message from the node to its successors and update them. """Send message from the node to its successors and update them.
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
The node that sends out messages. The node that sends out messages.
message_func : str or callable message_func : callable
The message function. The message function.
reduce_func : str or callable reduce_func : callable
The reduce function. The reduce function.
update_func : str or callable apply_node_func : callable
The update function. The update function.
batchable : bool batchable : bool
Whether the reduce and update function allows batched computation. Whether the reduce and update function allows batched computation.
""" """
if message_func is None: u = utils.toindex(u)
message_func, batchable = self._message_func if len(u) == 0:
if reduce_func is None: return
reduce_func, batchable = self._reduce_func uu, vv, _ = self.cached_graph.out_edges(u)
if update_func is None: self.send_and_recv(uu, vv, message_func,
update_func, batchable = self._update_func reduce_func, apply_node_func, batchable=batchable)
assert message_func is not None
assert reduce_func is not None
assert update_func is not None
if batchable:
u = utils.toindex(u)
uu, vv, _ = self.cached_graph.out_edges(u)
self._batch_update_by_edge(uu, vv, message_func,
reduce_func, update_func)
else:
u = utils.toindex(u)
for uu in utils.node_iter(u):
assert uu in self.nodes
for v in self.succ[uu]:
self._nonbatch_update_by_edge(uu, v,
message_func, reduce_func, update_func)
def update_all(self, def update_all(self,
message_func=None, message_func="default",
reduce_func=None, reduce_func="default",
update_func=None, apply_node_func="default",
batchable=False): batchable=False):
"""Send messages through all the edges and update all nodes. """Send messages through all the edges and update all nodes.
Parameters Parameters
---------- ----------
message_func : str or callable message_func : callable
The message function. The message function.
reduce_func : str or callable reduce_func : callable
The reduce function. The reduce function.
update_func : str or callable apply_node_func : callable, optional
The update function. The update function.
batchable : bool batchable : bool
Whether the reduce and update function allows batched computation. Whether the reduce and update function allows batched computation.
""" """
if message_func is None: if message_func == "default":
message_func, batchable = self._message_func message_func, batchable = self._message_func
if reduce_func is None: if reduce_func == "default":
reduce_func, batchable = self._reduce_func reduce_func, _ = self._reduce_func
if update_func is None:
update_func, batchable = self._update_func
assert message_func is not None assert message_func is not None
assert reduce_func is not None assert reduce_func is not None
assert update_func is not None
if batchable: if batchable:
if message_func == 'from_src' and reduce_func == 'sum': executor = scheduler.get_executor(
# TODO(minjie): use lazy dict for reduced_msgs "update_all", self, message_func=message_func, reduce_func=reduce_func)
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
adjmat = self.cached_graph.adjmat(F.get_context(col))
reduced_msgs[key] = F.spmm(adjmat, col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr()
self.set_n_repr(update_func(node_repr, reduced_msgs))
else:
self._batch_sendto(ALL, ALL, message_func)
self._batch_recv(ALL, reduce_func, update_func)
else: else:
u, v = zip(*self.edges) executor = None
u = list(u)
v = list(v) if executor:
self._nonbatch_sendto(u, v, message_func) executor.run()
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func) else:
self.send(ALL, ALL, message_func, batchable=batchable)
self.recv(ALL, reduce_func, None, batchable=batchable)
self.apply_nodes(ALL, apply_node_func, batchable=batchable)
def propagate(self, def propagate(self,
iterator='bfs', iterator='bfs',
message_func=None, message_func="default",
reduce_func=None, reduce_func="default",
update_func=None, apply_node_func="default",
batchable=False, batchable=False,
**kwargs): **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
...@@ -910,7 +834,7 @@ class DGLGraph(DiGraph): ...@@ -910,7 +834,7 @@ class DGLGraph(DiGraph):
The message function. The message function.
reduce_func : str or callable reduce_func : str or callable
The reduce function. The reduce function.
update_func : str or callable apply_node_func : str or callable
The update function. The update function.
batchable : bool batchable : bool
Whether the reduce and update function allows batched computation. Whether the reduce and update function allows batched computation.
...@@ -925,8 +849,8 @@ class DGLGraph(DiGraph): ...@@ -925,8 +849,8 @@ class DGLGraph(DiGraph):
else: else:
# NOTE: the iteration can return multiple edges at each step. # NOTE: the iteration can return multiple edges at each step.
for u, v in iterator: for u, v in iterator:
self.update_by_edge(u, v, self.send_and_recv(u, v,
message_func, reduce_func, update_func, batchable) message_func, reduce_func, apply_node_func, batchable)
def subgraph(self, nodes): def subgraph(self, nodes):
"""Generate the subgraph among the given nodes. """Generate the subgraph among the given nodes.
...@@ -1077,25 +1001,3 @@ def _set_repr(attr_dict, attr): ...@@ -1077,25 +1001,3 @@ def _set_repr(attr_dict, attr):
attr_dict.update(attr) attr_dict.update(attr)
else: else:
attr_dict[__REPR__] = attr attr_dict[__REPR__] = attr
def _get_reduce_func(reduce_func):
if isinstance(reduce_func, str):
# built-in reduce func
if reduce_func == 'sum':
return builtin.reduce_sum
elif reduce_func == 'max':
return builtin.reduce_max
else:
raise ValueError(
"Unknown built-in reduce function: %s" % reduce_func)
return reduce_func
def _get_message_func(message_func):
if isinstance(message_func, str):
# built-in message func
if message_func == 'from_src':
return builtin.message_from_src
else:
raise ValueError(
"Unknown built-in message function: %s" % message_func)
return message_func
...@@ -6,6 +6,9 @@ Code: https://github.com/tkipf/gcn ...@@ -6,6 +6,9 @@ Code: https://github.com/tkipf/gcn
GCN with SPMV specialization. GCN with SPMV specialization.
""" """
import torch.nn as nn import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.base import ALL, is_all from dgl.base import ALL, is_all
class NodeUpdateModule(nn.Module): class NodeUpdateModule(nn.Module):
...@@ -15,13 +18,8 @@ class NodeUpdateModule(nn.Module): ...@@ -15,13 +18,8 @@ class NodeUpdateModule(nn.Module):
self.activation = activation self.activation = activation
self.attribute = None self.attribute = None
def set_attribute_to_update(self, attribute): def forward(self, node):
self.attribute = attribute h = self.linear(node['accum'])
def forward(self, node, accum, attribute=None):
if self.attribute:
accum = accum[self.attribute]
h = self.linear(accum)
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
if self.attribute: if self.attribute:
...@@ -41,9 +39,16 @@ class GCN(nn.Module): ...@@ -41,9 +39,16 @@ class GCN(nn.Module):
self.update_func = NodeUpdateModule(in_feats, out_feats, activation) self.update_func = NodeUpdateModule(in_feats, out_feats, activation)
def forward(self, g, u=ALL, v=ALL, attribute=None): def forward(self, g, u=ALL, v=ALL, attribute=None):
self.update_func.set_attribute_to_update(attribute)
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
g.update_all('from_src', 'sum', self.update_func, batchable=True) g.update_all(fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
else: else:
g.update_by_edge(u, v, 'from_src', 'sum', self.update_func, batchable=True) g.send_and_recv(u, v,
fn.copy_src(src=attribute),
fn.sum(out='accum'),
self.update_func,
batchable=True)
g.pop_n_repr('accum')
return g return g
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment