"examples/pytorch/appnp/train.py" did not exist on "cdbeb17f2d6cefa1636a175d37d6565b24ea7c1f"
Unverified commit 61fa3c6c, authored by Minjie Wang and committed by GitHub

Builtin function and API changes (#53)

* WIP: API renaming

* API rewrite and node function refactor

* builtin functions

* builtin functions tested

* fix test

* send and recv spmv test

* WIP: fix examples

* Fix examples using new APIs
parent 8c71f3f8
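For orientation, the per-file diffs below converge on one calling convention: reduce functions return a dict of node fields instead of a bare tensor, the old per-node update functions become apply-node functions that take only `node`, and string builtins such as 'from_src'/'sum' are replaced by function objects from dgl.function. A minimal sketch of that shape, pieced together from the examples in this commit; the toy graph, feature sizes, and field names ('h', 'm') are illustrative only:

import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl import DGLGraph

class NodeApplyModule(nn.Module):
    """Apply-node function: takes only `node`, returns a dict of new fields."""
    def __init__(self, in_feats, out_feats):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, node):
        return {'h' : F.relu(self.linear(node['h']))}

g = DGLGraph(nx.path_graph(4))             # toy 4-node graph
g.set_n_repr({'h' : torch.ones(4, 8)})     # node field 'h'
# builtin function objects replace the old 'from_src' / 'sum' strings
g.update_all(fn.copy_src(src='h', out='m'),
             fn.sum(msgs='m', out='h'),
             NodeApplyModule(8, 8),
             batchable=True)
h = g.pop_n_repr('h')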
@@ -30,7 +30,7 @@ class GATReduce(nn.Module):
         e = F.softmax(F.leaky_relu(a), dim=0)
         if self.attn_drop != 0.0:
             e = F.dropout(e, self.attn_drop)
-        return torch.sum(e * ft, dim=0)  # shape (D,)
+        return {'accum' : torch.sum(e * ft, dim=0)}  # shape (D,)
 
 class GATFinalize(nn.Module):
     def __init__(self, headid, indim, hiddendim, activation, residual):
@@ -43,8 +43,8 @@ class GATFinalize(nn.Module):
         if indim != hiddendim:
             self.residual_fc = nn.Linear(indim, hiddendim)
 
-    def forward(self, node, accum):
-        ret = accum
+    def forward(self, node):
+        ret = node['accum']
         if self.residual:
             if self.residual_fc is not None:
                 ret = self.residual_fc(node['h']) + ret
......
@@ -33,7 +33,7 @@ class GATReduce(nn.Module):
         e = F.softmax(F.leaky_relu(a), dim=1)
         if self.attn_drop != 0.0:
             e = F.dropout(e, self.attn_drop)
-        return torch.sum(e * ft, dim=1)  # shape (B, D)
+        return {'accum' : torch.sum(e * ft, dim=1)}  # shape (B, D)
 
 class GATFinalize(nn.Module):
     def __init__(self, headid, indim, hiddendim, activation, residual):
@@ -46,8 +46,8 @@ class GATFinalize(nn.Module):
         if indim != hiddendim:
             self.residual_fc = nn.Linear(indim, hiddendim)
 
-    def forward(self, node, accum):
-        ret = accum
+    def forward(self, node):
+        ret = node['accum']
         if self.residual:
             if self.residual_fc is not None:
                 ret = self.residual_fc(node['h']) + ret
......
@@ -16,16 +16,16 @@ def gcn_msg(src, edge):
     return src['h']
 
 def gcn_reduce(node, msgs):
-    return sum(msgs)
+    return {'h' : sum(msgs)}
 
-class NodeUpdateModule(nn.Module):
+class NodeApplyModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
-        super(NodeUpdateModule, self).__init__()
+        super(NodeApplyModule, self).__init__()
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation
 
-    def forward(self, node, accum):
-        h = self.linear(accum)
+    def forward(self, node):
+        h = self.linear(node['h'])
         if self.activation:
             h = self.activation(h)
         return {'h' : h}
@@ -43,12 +43,12 @@ class GCN(nn.Module):
         self.g = DGLGraph(nx_graph)
         self.dropout = dropout
         # input layer
-        self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
+        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
         # hidden layers
         for i in range(n_layers - 1):
-            self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
+            self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
         # output layer
-        self.layers.append(NodeUpdateModule(n_hidden, n_classes))
+        self.layers.append(NodeApplyModule(n_hidden, n_classes))
 
     def forward(self, features, train_nodes):
         for n, feat in features.items():
......
@@ -21,14 +21,14 @@ def gcn_msg(src, edge):
 def gcn_reduce(node, msgs):
     return torch.sum(msgs, 1)
 
-class NodeUpdateModule(nn.Module):
+class NodeApplyModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
-        super(NodeUpdateModule, self).__init__()
+        super(NodeApplyModule, self).__init__()
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation
 
-    def forward(self, node, accum):
-        h = self.linear(accum)
+    def forward(self, node):
+        h = self.linear(node)
         if self.activation:
             h = self.activation(h)
         return h
@@ -46,12 +46,12 @@ class GCN(nn.Module):
         self.g = g
         self.dropout = dropout
         # input layer
-        self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
+        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
         # hidden layers
         for i in range(n_layers - 1):
-            self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
+            self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
         # output layer
-        self.layers.append(NodeUpdateModule(n_hidden, n_classes))
+        self.layers.append(NodeApplyModule(n_hidden, n_classes))
 
     def forward(self, features):
         self.g.set_n_repr(features)
......
@@ -12,17 +12,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import dgl
+import dgl.function as fn
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data
 
-class NodeUpdateModule(nn.Module):
+class NodeApplyModule(nn.Module):
     def __init__(self, in_feats, out_feats, activation=None):
-        super(NodeUpdateModule, self).__init__()
+        super(NodeApplyModule, self).__init__()
         self.linear = nn.Linear(in_feats, out_feats)
         self.activation = activation
 
-    def forward(self, node, accum):
-        h = self.linear(accum)
+    def forward(self, node):
+        h = self.linear(node)
         if self.activation:
             h = self.activation(h)
         return h
@@ -40,12 +41,12 @@ class GCN(nn.Module):
         self.g = g
         self.dropout = dropout
         # input layer
-        self.layers = nn.ModuleList([NodeUpdateModule(in_feats, n_hidden, activation)])
+        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
         # hidden layers
         for i in range(n_layers - 1):
-            self.layers.append(NodeUpdateModule(n_hidden, n_hidden, activation))
+            self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
         # output layer
-        self.layers.append(NodeUpdateModule(n_hidden, n_classes))
+        self.layers.append(NodeApplyModule(n_hidden, n_classes))
 
     def forward(self, features):
         self.g.set_n_repr(features)
@@ -54,7 +55,7 @@ class GCN(nn.Module):
             if self.dropout:
                 val = F.dropout(self.g.get_n_repr(), p=self.dropout)
                 self.g.set_n_repr(val)
-            self.g.update_all('from_src', 'sum', layer, batchable=True)
+            self.g.update_all(fn.copy_src(), fn.sum(), layer, batchable=True)
         return self.g.pop_n_repr()
 
 def main(args):
......
@@ -21,10 +21,10 @@ import dgl.context as ctx
 def tensor_topo_traverse(g, cuda, args):
     n = g.number_of_nodes()
     if cuda:
-        adjmat = g.cached_graph.adjmat(ctx.gpu(args.gpu))
+        adjmat = g.cached_graph.adjmat().get(ctx.gpu(args.gpu))
         mask = th.ones((n, 1)).cuda()
     else:
-        adjmat = g.cached_graph.adjmat(ctx.cpu())
+        adjmat = g.cached_graph.adjmat().get(ctx.cpu())
         mask = th.ones((n, 1))
     degree = th.spmm(adjmat, mask)
     while th.sum(mask) != 0.:
@@ -59,6 +59,9 @@ def main(args):
                     args.dropout)
     if cuda:
         model.cuda()
+        zero_initializer = lambda shape : th.zeros(shape).cuda()
+    else:
+        zero_initializer = th.zeros
     print(model)
     optimizer = optim.Adagrad(model.parameters(),
                               lr=args.lr,
@@ -68,21 +71,14 @@ def main(args):
         t_epoch = time.time()
         for step, batch in enumerate(train_loader):
             g = batch.graph
-            n = g.number_of_nodes()
-            x = th.zeros((n, args.x_size))
-            h = th.zeros((n, args.h_size))
-            c = th.zeros((n, args.h_size))
             if cuda:
                 batch = _batch_to_cuda(batch)
-                x = x.cuda()
-                h = h.cuda()
-                c = c.cuda()
             if step >= 3:
                 t0 = time.time()
             # traverse graph
             giter = list(tensor_topo_traverse(g, False, args))
-            logits = model(batch, x, h, c, iterator=giter, train=True)
+            logits = model(batch, zero_initializer, iterator=giter, train=True)
             logp = F.log_softmax(logits, 1)
             loss = F.nll_loss(logp, batch.label)
             optimizer.zero_grad()
......
@@ -52,19 +52,13 @@ class ChildSumTreeLSTMCell(nn.Module):
         c_tild = th.sum(f * msgs['c'], 1)
         return {'h_tild' : h_tild, 'c_tild' : c_tild}
 
-    def update_func(self, node, accum):
+    def apply_func(self, node):
         # equation (3), (5), (6)
-        if accum is None:
-            iou = self.W_iou(node['x'])
-        else:
-            iou = self.W_iou(node['x']) + self.U_iou(accum['h_tild'])
+        iou = self.W_iou(node['x']) + self.U_iou(node['h_tild'])
         i, o, u = th.chunk(iou, 3, 1)
         i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
         # equation (7)
-        if accum is None:
-            c = i * u
-        else:
-            c = i * u + accum['c_tild']
+        c = i * u + node['c_tild']
         # equation (8)
         h = o * th.tanh(c)
         return {'h' : h, 'c' : c}
@@ -79,6 +73,7 @@ class TreeLSTM(nn.Module):
                  cell_type='childsum'):
         super(TreeLSTM, self).__init__()
         self.x_size = x_size
+        self.h_size = h_size
         # TODO(minjie): pre-trained embedding like GLoVe
         self.embedding = nn.Embedding(num_vocabs, x_size)
         self.dropout = nn.Dropout(dropout)
@@ -88,20 +83,20 @@ class TreeLSTM(nn.Module):
         else:
             raise RuntimeError('Unknown cell type:', cell_type)
 
-    def forward(self, batch, x, h, c, iterator=None, train=True):
+    def forward(self, batch, zero_initializer, h=None, c=None, iterator=None, train=True):
         """Compute tree-lstm prediction given a batch.
 
         Parameters
         ----------
         batch : dgl.data.SSTBatch
             The data batch.
-        x : Tensor
-            Initial node input.
-        h : Tensor
+        zero_initializer : callable
+            Function to return zero value tensor.
+        h : Tensor, optional
             Initial hidden state.
-        c : Tensor
+        c : Tensor, optional
             Initial cell state.
-        iterator : graph iterator
+        iterator : graph iterator, optional
             External iterator on graph.
 
         Returns
@@ -110,21 +105,29 @@ class TreeLSTM(nn.Module):
             The prediction of each node.
         """
         g = batch.graph
+        n = g.number_of_nodes()
         g.register_message_func(self.cell.message_func, batchable=True)
         g.register_reduce_func(self.cell.reduce_func, batchable=True)
-        g.register_update_func(self.cell.update_func, batchable=True)
+        g.register_apply_node_func(self.cell.apply_func, batchable=True)
         # feed embedding
         embeds = self.embedding(batch.wordid)
+        x = zero_initializer((n, self.x_size))
         x = x.index_copy(0, batch.nid_with_word, embeds)
-        g.set_n_repr({'x' : x, 'h' : h, 'c' : c})
+        if h is None:
+            h = zero_initializer((n, self.h_size))
+        h_tild = zero_initializer((n, self.h_size))
+        if c is None:
+            c = zero_initializer((n, self.h_size))
+        c_tild = zero_initializer((n, self.h_size))
+        g.set_n_repr({'x' : x, 'h' : h, 'c' : c, 'h_tild' : h_tild, 'c_tild' : c_tild})
         # TODO(minjie): potential bottleneck
         if iterator is None:
             for frontier in topological_traverse(g):
                 #print('frontier', frontier)
-                g.update_to(frontier)
+                g.pull(frontier)
         else:
             for frontier in iterator:
-                g.update_to(frontier)
+                g.pull(frontier)
         # compute logits
         h = g.pop_n_repr('h')
         h = self.dropout(h)
......
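The two tree-lstm diffs above replace the pre-allocated x/h/c tensors with a zero_initializer callable: the training script picks the factory once based on the device, and TreeLSTM.forward allocates correctly sized, correctly placed zero states on demand. A small sketch of the pattern in isolation (the sizes here are illustrative only):

import torch as th

cuda = th.cuda.is_available()
if cuda:
    zero_initializer = lambda shape : th.zeros(shape).cuda()
else:
    zero_initializer = th.zeros

# inside the model, states are then created lazily per batch
n, h_size = 10, 16                     # illustrative sizes
h = zero_initializer((n, h_size))      # zeros on the same device as the model
c = zero_initializer((n, h_size))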
@@ -63,6 +63,7 @@ zeros = th.zeros
 spmm = th.spmm
 sort = th.sort
 arange = th.arange
+mul = th.mul
 
 def to_context(x, ctx):
     if ctx is None:
......
@@ -2,5 +2,9 @@
 # A special argument for selecting all nodes/edges.
 ALL = "__ALL__"
 
 def is_all(arg):
     return isinstance(arg, str) and arg == ALL
+
+__MSG__ = "__MSG__"
+__REPR__ = "__REPR__"
"""Built-in functors."""
from __future__ import absolute_import
import dgl.backend as F
def message_from_src(src, edge):
return src
def reduce_sum(node, msgs):
if isinstance(msgs, list):
if isinstance(msgs[0], dict):
return {k : sum(m[k] for m in msgs) for k in msgs[0].keys()}
else:
return sum(msgs)
else:
return F.sum(msgs, 1)
def reduce_max(node, msgs):
if isinstance(msgs, list):
return max(msgs)
else:
return F.max(msgs, 1)
@@ -118,8 +118,8 @@ class CachedGraph:
         dst = utils.toindex(dst)
         return src, dst
 
-    @utils.ctx_cached_member
-    def adjmat(self, ctx):
+    @utils.cached_member
+    def adjmat(self):
         """Return a sparse adjacency matrix.
 
         The row dimension represents the dst nodes; the column dimension
@@ -134,8 +134,7 @@ class CachedGraph:
         n = self._graph.vcount()
         dat = F.ones((len(elist),))
         mat = F.sparse_tensor(idx, dat, [n, n])
-        mat = F.to_context(mat, ctx)
-        return mat
+        return utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
 
     def freeze(self):
         self._freeze = True
......
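With this change adjmat() builds the sparse matrix once and returns a context-cached wrapper; callers such as the tree-lstm traversal above fetch the device copy with .get(ctx). utils.CtxCachedObject itself is not shown in this commit, so the following is only a sketch of what it presumably does, under that assumption:

class CtxCachedObject(object):
    """Lazily materialize and cache one copy of an object per context (sketch)."""
    def __init__(self, generator):
        # generator: callable mapping a context to the object on that context,
        # e.g. lambda ctx: F.to_context(mat, ctx) as in adjmat() above
        self._generator = generator
        self._ctx_cache = {}

    def get(self, ctx):
        if ctx not in self._ctx_cache:
            self._ctx_cache[ctx] = self._generator(ctx)
        return self._ctx_cache[ctx]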
@@ -22,7 +22,6 @@ _urls = {
 class CitationGraphDataset(object):
     def __init__(self, name):
         self.name = name
-        self.mode = mode
         self.dir = get_download_dir()
         self.zip_file_path='{}/{}.zip'.format(self.dir, name)
         download(_urls[name], path=self.zip_file_path)
......
@@ -163,6 +163,10 @@ class FrameRef(MutableMapping):
 
     def update_rows(self, query, other):
         rowids = self._getrowid(query)
         for key, col in other.items():
+            if key not in self:
+                # add new column
+                tmpref = FrameRef(self._frame, rowids)
+                tmpref.add_column(key, col)
             idx = rowids.totensor(F.get_context(self._frame[key]))
             self._frame[key] = F.scatter_row(self._frame[key], idx, col)
......
from .message import *
from .reducer import *
"""Built-in message function."""
from __future__ import absolute_import
import operator
__all__ = ["MessageFunction", "src_mul_edge", "copy_src", "copy_edge"]
class MessageFunction(object):
def __call__(self, src, edge):
raise NotImplementedError
def name(self):
raise NotImplementedError
class BundledMessageFunction(MessageFunction):
def __init__(self, fn_list):
self.fn_list = fn_list
def __call__(self, src, edge):
ret = None
for fn in self.fn_list:
msg = fn(src, edge)
if ret is None:
ret = msg
else:
try:
ret.update(msg)
except e:
raise RuntimeError("Failed to merge results of two builtin"
" message functions. Please specify out_field"
" for the builtin message function.")
return ret
def name(self):
return "bundled"
class SrcMulEdgeMessageFunction(MessageFunction):
def __init__(self, mul_op, src_field=None, edge_field=None, out_field=None):
self.mul_op = mul_op
self.src_field = src_field
self.edge_field = edge_field
self.out_field = out_field
def __call__(self, src, edge):
if self.src_field is not None:
src = src[self.src_field]
if self.edge_field is not None:
edge = edge[self.edge_field]
ret = self.mul_op(src, edge)
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self):
return "src_mul_edge"
class CopySrcMessageFunction(MessageFunction):
def __init__(self, src_field=None, out_field=None):
self.src_field = src_field
self.out_field = out_field
def __call__(self, src, edge):
if self.src_field is not None:
ret = src[self.src_field]
else:
ret = src
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self):
return "copy_src"
class CopyEdgeMessageFunction(MessageFunction):
def __init__(self, edge_field=None, out_field=None):
self.edge_field = edge_field
self.out_field = out_field
def __call__(self, src, edge):
if self.edge_field is not None:
ret = edge[self.edge_field]
else:
ret = edge
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self):
return "copy_edge"
def src_mul_edge(src=None, edge=None, out=None):
"""TODO(minjie): docstring """
return SrcMulEdgeMessageFunction(operator.mul, src, edge, out)
def copy_src(src=None, out=None):
"""TODO(minjie): docstring """
return CopySrcMessageFunction(src, out)
def copy_edge(edge=None, out=None):
"""TODO(minjie): docstring """
return CopyEdgeMessageFunction(edge, out)
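A quick illustration of how the builtin message functions above select and rename fields; the field names 'h', 'w', and 'm' are placeholders, not part of the API:

import dgl.function as fn

msg_copy = fn.copy_src(src='h', out='m')                     # copy source field 'h' into message field 'm'
msg_weighted = fn.src_mul_edge(src='h', edge='w', out='m')   # element-wise src['h'] * edge['w']

print(msg_copy({'h': 3.0}, {'w': 2.0}))       # {'m': 3.0}
print(msg_weighted({'h': 3.0}, {'w': 2.0}))   # {'m': 6.0}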
"""Built-in reducer function."""
from __future__ import absolute_import
import dgl.backend as F
__all__ = ["ReduceFunction", "sum", "max"]
class ReduceFunction(object):
def __call__(self, node, msgs):
raise NotImplementedError
def name(self):
raise NotImplementedError
class BundledReduceFunction(ReduceFunction):
def __init__(self, fn_list):
self.fn_list = fn_list
def __call__(self, node, msgs):
ret = None
for fn in self.fn_list:
rpr = fn(node, msgs)
if ret is None:
ret = rpr
else:
try:
ret.update(rpr)
except e:
raise RuntimeError("Failed to merge results of two builtin"
" reduce functions. Please specify out_field"
" for the builtin reduce function.")
return ret
def name(self):
return "bundled"
class SumReducerFunction(ReduceFunction):
def __init__(self, batch_sum_op, nonbatch_sum_op, msg_field=None, out_field=None):
self.batch_sum_op = batch_sum_op
self.nonbatch_sum_op = nonbatch_sum_op
self.msg_field = msg_field
self.out_field = out_field
def __call__(self, node, msgs):
if isinstance(msgs, list):
if self.msg_field is None:
ret = self.nonbatch_sum_op(msgs)
else:
ret = self.nonbatch_sum_op([msg[self.msg_field] for msg in msgs])
else:
if self.msg_field is None:
ret = self.batch_sum_op(msgs, 1)
else:
ret = self.batch_sum_op(msgs[self.msg_field], 1)
if self.out_field is None:
return ret
else:
return {self.out_field : ret}
def name(self):
return "sum"
_python_sum = sum
def sum(msgs=None, out=None):
return SumReducerFunction(F.sum, _python_sum, msgs, out)
_python_max = max
def max(msgs=None, out=None):
return SumReducerFunction(F.max, _python_max, msgs, out)
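And the matching reducers. fn.sum handles both layouts a reduce function can see: the degree-bucketing path, where messages arrive as a list of per-edge dicts, and the batched path, where they arrive stacked along dimension 1. Field names are again placeholders:

import torch as th
import dgl.function as fn

reduce_sum = fn.sum(msgs='m', out='accum')

# degree-bucketing path: a list of per-edge message dicts
print(reduce_sum(None, [{'m': th.tensor([1., 2.])},
                        {'m': th.tensor([3., 4.])}]))   # {'accum': tensor([4., 6.])}

# batched path: messages stacked as (num_nodes, degree, dim)
out = reduce_sum(None, {'m': th.ones(2, 3, 4)})
print(out['accum'].shape)   # torch.Size([2, 4]); every entry of out['accum'] is 3.0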
@@ -6,6 +6,9 @@ Code: https://github.com/tkipf/gcn
 GCN with SPMV specialization.
 """
 import torch.nn as nn
+
+import dgl
+import dgl.function as fn
 from dgl.base import ALL, is_all
 
 class NodeUpdateModule(nn.Module):
@@ -15,13 +18,8 @@ class NodeUpdateModule(nn.Module):
         self.activation = activation
         self.attribute = None
 
-    def set_attribute_to_update(self, attribute):
-        self.attribute = attribute
-
-    def forward(self, node, accum, attribute=None):
-        if self.attribute:
-            accum = accum[self.attribute]
-        h = self.linear(accum)
+    def forward(self, node):
+        h = self.linear(node['accum'])
         if self.activation:
             h = self.activation(h)
         if self.attribute:
@@ -41,9 +39,16 @@ class GCN(nn.Module):
         self.update_func = NodeUpdateModule(in_feats, out_feats, activation)
 
     def forward(self, g, u=ALL, v=ALL, attribute=None):
-        self.update_func.set_attribute_to_update(attribute)
         if is_all(u) and is_all(v):
-            g.update_all('from_src', 'sum', self.update_func, batchable=True)
+            g.update_all(fn.copy_src(src=attribute),
+                         fn.sum(out='accum'),
+                         self.update_func,
+                         batchable=True)
         else:
-            g.update_by_edge(u, v, 'from_src', 'sum', self.update_func, batchable=True)
+            g.send_and_recv(u, v,
+                            fn.copy_src(src=attribute),
+                            fn.sum(out='accum'),
+                            self.update_func,
+                            batchable=True)
+        g.pop_n_repr('accum')
         return g