Unverified Commit c380f0b5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update gatv2.py (#3921)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent e06e63d5
...@@ -45,7 +45,7 @@ class GATv2(nn.Module): ...@@ -45,7 +45,7 @@ class GATv2(nn.Module):
def forward(self, g, inputs): def forward(self, g, inputs):
h = inputs h = inputs
for l in range(self.num_layers): for l in range(self.num_layers):
h = self.gatv2_layers[l](h).flatten(1) h = self.gatv2_layers[l](g, h).flatten(1)
# output projection # output projection
logits = self.gatv2_layers[-1](h).mean(1) logits = self.gatv2_layers[-1](g, h).mean(1)
return logits return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment