Commit ea8b5d79 authored by VoVAllen's avatar VoVAllen
Browse files

fix #2278

parent 7b4b8129
...@@ -76,14 +76,9 @@ class GINConv(nn.Module): ...@@ -76,14 +76,9 @@ class GINConv(nn.Module):
super(GINConv, self).__init__() super(GINConv, self).__init__()
self.apply_func = apply_func self.apply_func = apply_func
self._aggregator_type = aggregator_type self._aggregator_type = aggregator_type
if aggregator_type == 'sum': if aggregator_type not in ('sum', 'max', 'mean'):
self._reducer = fn.sum raise KeyError(
elif aggregator_type == 'max': 'Aggregator type {} not recognized.'.format(aggregator_type))
self._reducer = fn.max
elif aggregator_type == 'mean':
self._reducer = fn.mean
else:
raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
# to specify whether eps is trainable or not. # to specify whether eps is trainable or not.
if learn_eps: if learn_eps:
self.eps = th.nn.Parameter(th.FloatTensor([init_eps])) self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
...@@ -120,6 +115,7 @@ class GINConv(nn.Module): ...@@ -120,6 +115,7 @@ class GINConv(nn.Module):
If ``apply_func`` is None, :math:`D_{out}` should be the same If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality. as input dimensionality.
""" """
_reducer = getattr(fn, self._aggregator_type)
with graph.local_scope(): with graph.local_scope():
aggregate_fn = fn.copy_src('h', 'm') aggregate_fn = fn.copy_src('h', 'm')
if edge_weight is not None: if edge_weight is not None:
...@@ -129,7 +125,7 @@ class GINConv(nn.Module): ...@@ -129,7 +125,7 @@ class GINConv(nn.Module):
feat_src, feat_dst = expand_as_pair(feat, graph) feat_src, feat_dst = expand_as_pair(feat, graph)
graph.srcdata['h'] = feat_src graph.srcdata['h'] = feat_src
graph.update_all(aggregate_fn, self._reducer('m', 'neigh')) graph.update_all(aggregate_fn, _reducer('m', 'neigh'))
rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh'] rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']
if self.apply_func is not None: if self.apply_func is not None:
rst = self.apply_func(rst) rst = self.apply_func(rst)
......
...@@ -779,12 +779,13 @@ def test_gin_conv(g, idtype, aggregator_type): ...@@ -779,12 +779,13 @@ def test_gin_conv(g, idtype, aggregator_type):
th.nn.Linear(5, 12), th.nn.Linear(5, 12),
aggregator_type aggregator_type
) )
th.save(gin, tmp_buffer)
feat = F.randn((g.number_of_src_nodes(), 5)) feat = F.randn((g.number_of_src_nodes(), 5))
gin = gin.to(ctx) gin = gin.to(ctx)
h = gin(g, feat) h = gin(g, feat)
# test pickle # test pickle
th.save(h, tmp_buffer) th.save(gin, tmp_buffer)
assert h.shape == (g.number_of_dst_nodes(), 12) assert h.shape == (g.number_of_dst_nodes(), 12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment