Unverified Commit 5f4c8793 authored by schmidt-ju's avatar schmidt-ju Committed by GitHub
Browse files

Fix output dimensions of residual connection for GATv2Conv (#3584)

Number of heads needs to be considered for the transformation of the destination features. Otherwise this will crash if residual==True.
parent a2241faf
...@@ -174,7 +174,7 @@ class GATv2Conv(nn.Module): ...@@ -174,7 +174,7 @@ class GATv2Conv(nn.Module):
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if residual: if residual:
if self._in_dst_feats != out_feats: if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear( self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=bias) self._in_dst_feats, num_heads * out_feats, bias=bias)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment