Unverified Commit 52d43127 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] Fix arma example (#4218)



* Fix arma example

* Update
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
parent e3b6ac8e
...@@ -64,7 +64,7 @@ class ARMAConv(nn.Module): ...@@ -64,7 +64,7 @@ class ARMAConv(nn.Module):
# assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees() # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
degs = g.in_degrees().float().clamp(min=1) degs = g.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1) norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
output = None output = []
for k in range(self.K): for k in range(self.K):
feats = init_feats feats = init_feats
...@@ -88,13 +88,9 @@ class ARMAConv(nn.Module): ...@@ -88,13 +88,9 @@ class ARMAConv(nn.Module):
if self.activation is not None: if self.activation is not None:
feats = self.activation(feats) feats = self.activation(feats)
output.append(feats)
if output is None:
output = feats return torch.stack(output).mean(dim=0)
else:
output += feats
return output / self.K
class ARMA4NC(nn.Module): class ARMA4NC(nn.Module):
def __init__(self, def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment