Unverified Commit 812f5970 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update appnpconv.py (#2324)

parent 88a042d3
...@@ -88,13 +88,16 @@ class APPNPConv(nn.Module): ...@@ -88,13 +88,16 @@ class APPNPConv(nn.Module):
should be the same as input shape. should be the same as input shape.
""" """
with graph.local_scope(): with graph.local_scope():
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5) src_norm = th.pow(graph.out_degrees().float().clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1) shp = src_norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device) src_norm = th.reshape(src_norm, shp).to(feat.device)
dst_norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
shp = dst_norm.shape + (1,) * (feat.dim() - 1)
dst_norm = th.reshape(dst_norm, shp).to(feat.device)
feat_0 = feat feat_0 = feat
for _ in range(self._k): for _ in range(self._k):
# normalization by src node # normalization by src node
feat = feat * norm feat = feat * src_norm
graph.ndata['h'] = feat graph.ndata['h'] = feat
graph.edata['w'] = self.edge_drop( graph.edata['w'] = self.edge_drop(
th.ones(graph.number_of_edges(), 1).to(feat.device)) th.ones(graph.number_of_edges(), 1).to(feat.device))
...@@ -102,6 +105,6 @@ class APPNPConv(nn.Module): ...@@ -102,6 +105,6 @@ class APPNPConv(nn.Module):
fn.sum('m', 'h')) fn.sum('m', 'h'))
feat = graph.ndata.pop('h') feat = graph.ndata.pop('h')
# normalization by dst node # normalization by dst node
feat = feat * norm feat = feat * dst_norm
feat = (1 - self._alpha) * feat + self._alpha * feat_0 feat = (1 - self._alpha) * feat + self._alpha * feat_0
return feat return feat
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment