Unverified Commit dc1c0ac2 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Example][Bug] Fix GAT example (#3947)



* fix

* oops
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 338db32d
......@@ -54,7 +54,6 @@ class GAT(nn.Module):
def forward(self, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gat_layers[l](self.g, h).flatten(1)
# output projection
logits = self.gat_layers[-1](self.g, h).mean(1)
return logits
h = self.gat_layers[l](self.g, h)
h = h.flatten(1) if l != self.num_layers - 1 else h.mean(1)
return h
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment