Unverified commit 29e66615, authored by DominikaJedynak, committed by GitHub

[Optimization] Optimize bias term in GatConv layer (#5466)

parent a244de57
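The change folds the bias into the residual projection whenever the residual connection is an `nn.Linear` (the new `bias=bias` argument), so the bias is applied inside that projection rather than as a separate broadcasted add in `forward()`; a standalone bias parameter is kept only when no such linear projection exists. A minimal sketch of the equivalence this relies on (shapes and variable names are illustrative, not from the commit):

```python
# Sketch (not part of the commit): a bias-free Linear followed by a
# separate bias add is equivalent to a single Linear with bias=True.
import torch as th
import torch.nn as nn

in_feats, num_heads, out_feats = 16, 4, 8
h = th.randn(32, in_feats)

# Before: bias-free residual projection, bias added separately in forward().
res_fc_old = nn.Linear(in_feats, num_heads * out_feats, bias=False)
bias = th.randn(num_heads * out_feats)
out_old = res_fc_old(h) + bias

# After: the bias lives inside the Linear, so the add happens as part of
# the projection instead of as an extra elementwise kernel.
res_fc_new = nn.Linear(in_feats, num_heads * out_feats, bias=True)
with th.no_grad():
    res_fc_new.weight.copy_(res_fc_old.weight)
    res_fc_new.bias.copy_(bias)
out_new = res_fc_new(h)

assert th.allclose(out_old, out_new, atol=1e-6)
```

Folding the add into the `nn.Linear` can let backends fuse it into the GEMM epilogue, removing one broadcasted add per forward pass.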
@@ -170,21 +170,28 @@ class GATConv(nn.Module):
         self.feat_drop = nn.Dropout(feat_drop)
         self.attn_drop = nn.Dropout(attn_drop)
         self.leaky_relu = nn.LeakyReLU(negative_slope)
-        if bias:
-            self.bias = nn.Parameter(
-                th.FloatTensor(size=(num_heads * out_feats,))
-            )
-        else:
-            self.register_buffer("bias", None)
+        self.has_linear_res = False
+        self.has_explicit_bias = False
         if residual:
             if self._in_dst_feats != out_feats * num_heads:
                 self.res_fc = nn.Linear(
-                    self._in_dst_feats, num_heads * out_feats, bias=False
+                    self._in_dst_feats, num_heads * out_feats, bias=bias
                 )
+                self.has_linear_res = True
             else:
                 self.res_fc = Identity()
         else:
             self.register_buffer("res_fc", None)
+        if bias and not self.has_linear_res:
+            self.bias = nn.Parameter(
+                th.FloatTensor(size=(num_heads * out_feats,))
+            )
+            self.has_explicit_bias = True
+        else:
+            self.register_buffer("bias", None)
         self.reset_parameters()
         self.activation = activation
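After this hunk the layer can end up in one of three configurations; a hedged summary as runnable asserts, assuming the usual `dgl.nn.GATConv` constructor where `num_heads` is the third positional argument:

```python
# Hedged sketch of the three resulting configurations (assumes the
# dgl.nn.GATConv signature: GATConv(in_feats, out_feats, num_heads, ...)).
import torch.nn as nn
from dgl.nn import GATConv

# 1) Linear residual projection: bias is folded into res_fc, no explicit bias.
conv = GATConv(16, 8, 4, residual=True, bias=True)
assert conv.has_linear_res and not conv.has_explicit_bias
assert isinstance(conv.res_fc, nn.Linear) and conv.res_fc.bias is not None

# 2) Identity residual (in_feats == num_heads * out_feats): explicit bias kept.
conv = GATConv(32, 8, 4, residual=True, bias=True)
assert not conv.has_linear_res and conv.has_explicit_bias
assert isinstance(conv.bias, nn.Parameter)

# 3) No residual: explicit bias kept, res_fc is None.
conv = GATConv(16, 8, 4, residual=False, bias=True)
assert conv.has_explicit_bias and conv.res_fc is None
```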
@@ -208,10 +215,12 @@ class GATConv(nn.Module):
             nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
         nn.init.xavier_normal_(self.attn_l, gain=gain)
         nn.init.xavier_normal_(self.attn_r, gain=gain)
-        if self.bias is not None:
+        if self.has_explicit_bias:
             nn.init.constant_(self.bias, 0)
         if isinstance(self.res_fc, nn.Linear):
             nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
+            if self.res_fc.bias is not None:
+                nn.init.constant_(self.res_fc.bias, 0)
 
     def set_allow_zero_in_degree(self, set_value):
         r"""
@@ -344,7 +353,7 @@ class GATConv(nn.Module):
             )
             rst = rst + resval
         # bias
-        if self.bias is not None:
+        if self.has_explicit_bias:
             rst = rst + self.bias.view(
                 *((1,) * len(dst_prefix_shape)),
                 self._num_heads,
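In `forward()`, the broadcasted bias add now runs only when `has_explicit_bias` is set; when the bias was folded into the linear residual it already reached `rst` via `resval`. A small sketch of the guarded add with assumed shapes:

```python
# Sketch of the guarded bias add above; assumes rst has shape
# (*dst_prefix_shape, num_heads, out_feats).
import torch as th

num_heads, out_feats = 4, 8
dst_prefix_shape = (32,)  # e.g. number of destination nodes
rst = th.randn(*dst_prefix_shape, num_heads, out_feats)
bias = th.zeros(num_heads * out_feats)

# The view reshapes the flat bias to (1, num_heads, out_feats) so it
# broadcasts over the leading node dimension(s).
rst = rst + bias.view(*((1,) * len(dst_prefix_shape)), num_heads, out_feats)
```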