Unverified Commit c380f0b5 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

Update gatv2.py (#3921)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent e06e63d5
......@@ -45,7 +45,7 @@ class GATv2(nn.Module):
def forward(self, g, inputs):
h = inputs
for l in range(self.num_layers):
h = self.gatv2_layers[l](h).flatten(1)
h = self.gatv2_layers[l](g, h).flatten(1)
# output projection
logits = self.gatv2_layers[-1](h).mean(1)
logits = self.gatv2_layers[-1](g, h).mean(1)
return logits
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment