Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ea8b5d79
Commit
ea8b5d79
authored
Nov 22, 2021
by
VoVAllen
Browse files
fix #2278
parent
7b4b8129
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
10 deletions
+7
-10
python/dgl/nn/pytorch/conv/ginconv.py
python/dgl/nn/pytorch/conv/ginconv.py
+5
-9
tests/pytorch/test_nn.py
tests/pytorch/test_nn.py
+2
-1
No files found.
python/dgl/nn/pytorch/conv/ginconv.py
View file @
ea8b5d79
...
@@ -76,14 +76,9 @@ class GINConv(nn.Module):
...
@@ -76,14 +76,9 @@ class GINConv(nn.Module):
super
(
GINConv
,
self
).
__init__
()
super
(
GINConv
,
self
).
__init__
()
self
.
apply_func
=
apply_func
self
.
apply_func
=
apply_func
self
.
_aggregator_type
=
aggregator_type
self
.
_aggregator_type
=
aggregator_type
if
aggregator_type
==
'sum'
:
if
aggregator_type
not
in
(
'sum'
,
'max'
,
'mean'
):
self
.
_reducer
=
fn
.
sum
raise
KeyError
(
elif
aggregator_type
==
'max'
:
'Aggregator type {} not recognized.'
.
format
(
aggregator_type
))
self
.
_reducer
=
fn
.
max
elif
aggregator_type
==
'mean'
:
self
.
_reducer
=
fn
.
mean
else
:
raise
KeyError
(
'Aggregator type {} not recognized.'
.
format
(
aggregator_type
))
# to specify whether eps is trainable or not.
# to specify whether eps is trainable or not.
if
learn_eps
:
if
learn_eps
:
self
.
eps
=
th
.
nn
.
Parameter
(
th
.
FloatTensor
([
init_eps
]))
self
.
eps
=
th
.
nn
.
Parameter
(
th
.
FloatTensor
([
init_eps
]))
...
@@ -120,6 +115,7 @@ class GINConv(nn.Module):
...
@@ -120,6 +115,7 @@ class GINConv(nn.Module):
If ``apply_func`` is None, :math:`D_{out}` should be the same
If ``apply_func`` is None, :math:`D_{out}` should be the same
as input dimensionality.
as input dimensionality.
"""
"""
_reducer
=
getattr
(
fn
,
self
.
_aggregator_type
)
with
graph
.
local_scope
():
with
graph
.
local_scope
():
aggregate_fn
=
fn
.
copy_src
(
'h'
,
'm'
)
aggregate_fn
=
fn
.
copy_src
(
'h'
,
'm'
)
if
edge_weight
is
not
None
:
if
edge_weight
is
not
None
:
...
@@ -129,7 +125,7 @@ class GINConv(nn.Module):
...
@@ -129,7 +125,7 @@ class GINConv(nn.Module):
feat_src
,
feat_dst
=
expand_as_pair
(
feat
,
graph
)
feat_src
,
feat_dst
=
expand_as_pair
(
feat
,
graph
)
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
srcdata
[
'h'
]
=
feat_src
graph
.
update_all
(
aggregate_fn
,
self
.
_reducer
(
'm'
,
'neigh'
))
graph
.
update_all
(
aggregate_fn
,
_reducer
(
'm'
,
'neigh'
))
rst
=
(
1
+
self
.
eps
)
*
feat_dst
+
graph
.
dstdata
[
'neigh'
]
rst
=
(
1
+
self
.
eps
)
*
feat_dst
+
graph
.
dstdata
[
'neigh'
]
if
self
.
apply_func
is
not
None
:
if
self
.
apply_func
is
not
None
:
rst
=
self
.
apply_func
(
rst
)
rst
=
self
.
apply_func
(
rst
)
...
...
tests/pytorch/test_nn.py
View file @
ea8b5d79
...
@@ -779,12 +779,13 @@ def test_gin_conv(g, idtype, aggregator_type):
...
@@ -779,12 +779,13 @@ def test_gin_conv(g, idtype, aggregator_type):
th
.
nn
.
Linear
(
5
,
12
),
th
.
nn
.
Linear
(
5
,
12
),
aggregator_type
aggregator_type
)
)
th
.
save
(
gin
,
tmp_buffer
)
feat
=
F
.
randn
((
g
.
number_of_src_nodes
(),
5
))
feat
=
F
.
randn
((
g
.
number_of_src_nodes
(),
5
))
gin
=
gin
.
to
(
ctx
)
gin
=
gin
.
to
(
ctx
)
h
=
gin
(
g
,
feat
)
h
=
gin
(
g
,
feat
)
# test pickle
# test pickle
th
.
save
(
h
,
tmp_buffer
)
th
.
save
(
gin
,
tmp_buffer
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
12
)
assert
h
.
shape
==
(
g
.
number_of_dst_nodes
(),
12
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment