"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "62d065ca4ab15eba9b5237258145ffb2d9b2ca49"
Commit ea8b5d79 authored by VoVAllen's avatar VoVAllen
Browse files

fix #2278

parent 7b4b8129
......@@ -76,14 +76,9 @@ class GINConv(nn.Module):
super(GINConv, self).__init__()
self.apply_func = apply_func
self._aggregator_type = aggregator_type
if aggregator_type == 'sum':
self._reducer = fn.sum
elif aggregator_type == 'max':
self._reducer = fn.max
elif aggregator_type == 'mean':
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
if aggregator_type not in ('sum', 'max', 'mean'):
raise KeyError(
'Aggregator type {} not recognized.'.format(aggregator_type))
# to specify whether eps is trainable or not.
if learn_eps:
self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
......@@ -120,6 +115,7 @@ class GINConv(nn.Module):
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
"""
_reducer = getattr(fn, self._aggregator_type)
with graph.local_scope():
aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None:
......@@ -129,7 +125,7 @@ class GINConv(nn.Module):
feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))
graph.update_all(aggregate_fn, _reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None:
rst = self.apply_func(rst)
......
......@@ -779,12 +779,13 @@ def test_gin_conv(g, idtype, aggregator_type):
th.nn.Linear(5, 12),
aggregator_type
)
th.save(gin, tmp_buffer)
feat = F.randn((g.number_of_src_nodes(), 5))
gin = gin.to(ctx)
h = gin(g, feat)
# test pickle
th.save(h, tmp_buffer)
th.save(gin, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), 12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment