Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
29e66615
Unverified
Commit
29e66615
authored
Apr 05, 2023
by
DominikaJedynak
Committed by
GitHub
Apr 05, 2023
Browse files
[Optimization] Optimize bias term in GatConv layer (#5466)
parent
a244de57
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
9 deletions
+18
-9
python/dgl/nn/pytorch/conv/gatconv.py
python/dgl/nn/pytorch/conv/gatconv.py
+18
-9
No files found.
python/dgl/nn/pytorch/conv/gatconv.py
View file @
29e66615
...
@@ -170,21 +170,28 @@ class GATConv(nn.Module):
...
@@ -170,21 +170,28 @@ class GATConv(nn.Module):
self
.
feat_drop
=
nn
.
Dropout
(
feat_drop
)
self
.
feat_drop
=
nn
.
Dropout
(
feat_drop
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
leaky_relu
=
nn
.
LeakyReLU
(
negative_slope
)
self
.
leaky_relu
=
nn
.
LeakyReLU
(
negative_slope
)
if
bias
:
self
.
bias
=
nn
.
Parameter
(
self
.
has_linear_res
=
False
th
.
FloatTensor
(
size
=
(
num_heads
*
out_feats
,))
self
.
has_explicit_bias
=
False
)
else
:
self
.
register_buffer
(
"bias"
,
None
)
if
residual
:
if
residual
:
if
self
.
_in_dst_feats
!=
out_feats
*
num_heads
:
if
self
.
_in_dst_feats
!=
out_feats
*
num_heads
:
self
.
res_fc
=
nn
.
Linear
(
self
.
res_fc
=
nn
.
Linear
(
self
.
_in_dst_feats
,
num_heads
*
out_feats
,
bias
=
False
self
.
_in_dst_feats
,
num_heads
*
out_feats
,
bias
=
bias
)
)
self
.
has_linear_res
=
True
else
:
else
:
self
.
res_fc
=
Identity
()
self
.
res_fc
=
Identity
()
else
:
else
:
self
.
register_buffer
(
"res_fc"
,
None
)
self
.
register_buffer
(
"res_fc"
,
None
)
if
bias
and
not
self
.
has_linear_res
:
self
.
bias
=
nn
.
Parameter
(
th
.
FloatTensor
(
size
=
(
num_heads
*
out_feats
,))
)
self
.
has_explicit_bias
=
True
else
:
self
.
register_buffer
(
"bias"
,
None
)
self
.
reset_parameters
()
self
.
reset_parameters
()
self
.
activation
=
activation
self
.
activation
=
activation
...
@@ -208,10 +215,12 @@ class GATConv(nn.Module):
...
@@ -208,10 +215,12 @@ class GATConv(nn.Module):
nn
.
init
.
xavier_normal_
(
self
.
fc_dst
.
weight
,
gain
=
gain
)
nn
.
init
.
xavier_normal_
(
self
.
fc_dst
.
weight
,
gain
=
gain
)
nn
.
init
.
xavier_normal_
(
self
.
attn_l
,
gain
=
gain
)
nn
.
init
.
xavier_normal_
(
self
.
attn_l
,
gain
=
gain
)
nn
.
init
.
xavier_normal_
(
self
.
attn_r
,
gain
=
gain
)
nn
.
init
.
xavier_normal_
(
self
.
attn_r
,
gain
=
gain
)
if
self
.
bias
is
not
None
:
if
self
.
has_explicit_bias
:
nn
.
init
.
constant_
(
self
.
bias
,
0
)
nn
.
init
.
constant_
(
self
.
bias
,
0
)
if
isinstance
(
self
.
res_fc
,
nn
.
Linear
):
if
isinstance
(
self
.
res_fc
,
nn
.
Linear
):
nn
.
init
.
xavier_normal_
(
self
.
res_fc
.
weight
,
gain
=
gain
)
nn
.
init
.
xavier_normal_
(
self
.
res_fc
.
weight
,
gain
=
gain
)
if
self
.
res_fc
.
bias
is
not
None
:
nn
.
init
.
constant_
(
self
.
res_fc
.
bias
,
0
)
def
set_allow_zero_in_degree
(
self
,
set_value
):
def
set_allow_zero_in_degree
(
self
,
set_value
):
r
"""
r
"""
...
@@ -344,7 +353,7 @@ class GATConv(nn.Module):
...
@@ -344,7 +353,7 @@ class GATConv(nn.Module):
)
)
rst
=
rst
+
resval
rst
=
rst
+
resval
# bias
# bias
if
self
.
bias
is
not
None
:
if
self
.
has_explicit_bias
:
rst
=
rst
+
self
.
bias
.
view
(
rst
=
rst
+
self
.
bias
.
view
(
*
((
1
,)
*
len
(
dst_prefix_shape
)),
*
((
1
,)
*
len
(
dst_prefix_shape
)),
self
.
_num_heads
,
self
.
_num_heads
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment