Unverified Commit 29e66615 authored by DominikaJedynak's avatar DominikaJedynak Committed by GitHub
Browse files

[Optimization] Optimize bias term in GatConv layer (#5466)

parent a244de57
...@@ -170,21 +170,28 @@ class GATConv(nn.Module): ...@@ -170,21 +170,28 @@ class GATConv(nn.Module):
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
if bias:
self.bias = nn.Parameter( self.has_linear_res = False
th.FloatTensor(size=(num_heads * out_feats,)) self.has_explicit_bias = False
)
else:
self.register_buffer("bias", None)
if residual: if residual:
if self._in_dst_feats != out_feats * num_heads: if self._in_dst_feats != out_feats * num_heads:
self.res_fc = nn.Linear( self.res_fc = nn.Linear(
self._in_dst_feats, num_heads * out_feats, bias=False self._in_dst_feats, num_heads * out_feats, bias=bias
) )
self.has_linear_res = True
else: else:
self.res_fc = Identity() self.res_fc = Identity()
else: else:
self.register_buffer("res_fc", None) self.register_buffer("res_fc", None)
if bias and not self.has_linear_res:
self.bias = nn.Parameter(
th.FloatTensor(size=(num_heads * out_feats,))
)
self.has_explicit_bias = True
else:
self.register_buffer("bias", None)
self.reset_parameters() self.reset_parameters()
self.activation = activation self.activation = activation
...@@ -208,10 +215,12 @@ class GATConv(nn.Module): ...@@ -208,10 +215,12 @@ class GATConv(nn.Module):
nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain) nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain) nn.init.xavier_normal_(self.attn_r, gain=gain)
if self.bias is not None: if self.has_explicit_bias:
nn.init.constant_(self.bias, 0) nn.init.constant_(self.bias, 0)
if isinstance(self.res_fc, nn.Linear): if isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain) nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
if self.res_fc.bias is not None:
nn.init.constant_(self.res_fc.bias, 0)
def set_allow_zero_in_degree(self, set_value): def set_allow_zero_in_degree(self, set_value):
r""" r"""
...@@ -344,7 +353,7 @@ class GATConv(nn.Module): ...@@ -344,7 +353,7 @@ class GATConv(nn.Module):
) )
rst = rst + resval rst = rst + resval
# bias # bias
if self.bias is not None: if self.has_explicit_bias:
rst = rst + self.bias.view( rst = rst + self.bias.view(
*((1,) * len(dst_prefix_shape)), *((1,) * len(dst_prefix_shape)),
self._num_heads, self._num_heads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment