Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
38ca4fb2
Commit
38ca4fb2
authored
Jun 08, 2021
by
rusty1s
Browse files
update super calls
parent
07932207
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
49 additions
and
104 deletions
+49
-104
small_benchmark/conf/model/gat.yaml
small_benchmark/conf/model/gat.yaml
+0
-6
torch_geometric_autoscale/models/appnp.py
torch_geometric_autoscale/models/appnp.py
+5
-9
torch_geometric_autoscale/models/base.py
torch_geometric_autoscale/models/base.py
+2
-2
torch_geometric_autoscale/models/gat.py
torch_geometric_autoscale/models/gat.py
+18
-48
torch_geometric_autoscale/models/gcn.py
torch_geometric_autoscale/models/gcn.py
+5
-9
torch_geometric_autoscale/models/gcn2.py
torch_geometric_autoscale/models/gcn2.py
+5
-9
torch_geometric_autoscale/models/pna.py
torch_geometric_autoscale/models/pna.py
+6
-11
torch_geometric_autoscale/models/pna_jk.py
torch_geometric_autoscale/models/pna_jk.py
+7
-9
torch_geometric_autoscale/pool.py
torch_geometric_autoscale/pool.py
+1
-1
No files found.
small_benchmark/conf/model/gat.yaml
View file @
38ca4fb2
...
@@ -10,7 +10,6 @@ params:
...
@@ -10,7 +10,6 @@ params:
hidden_channels
:
8
hidden_channels
:
8
hidden_heads
:
8
hidden_heads
:
8
out_heads
:
1
out_heads
:
1
residual
:
false
dropout
:
0.6
dropout
:
0.6
num_parts
:
40
num_parts
:
40
batch_size
:
10
batch_size
:
10
...
@@ -28,7 +27,6 @@ params:
...
@@ -28,7 +27,6 @@ params:
hidden_channels
:
8
hidden_channels
:
8
hidden_heads
:
8
hidden_heads
:
8
out_heads
:
1
out_heads
:
1
residual
:
false
dropout
:
0.6
dropout
:
0.6
num_parts
:
24
num_parts
:
24
batch_size
:
8
batch_size
:
8
...
@@ -46,7 +44,6 @@ params:
...
@@ -46,7 +44,6 @@ params:
hidden_channels
:
8
hidden_channels
:
8
hidden_heads
:
8
hidden_heads
:
8
out_heads
:
8
out_heads
:
8
residual
:
false
dropout
:
0.6
dropout
:
0.6
num_parts
:
4
num_parts
:
4
batch_size
:
1
batch_size
:
1
...
@@ -64,7 +61,6 @@ params:
...
@@ -64,7 +61,6 @@ params:
hidden_channels
:
8
hidden_channels
:
8
hidden_heads
:
8
hidden_heads
:
8
out_heads
:
1
out_heads
:
1
residual
:
false
dropout
:
0.6
dropout
:
0.6
num_parts
:
8
num_parts
:
8
batch_size
:
2
batch_size
:
2
...
@@ -82,7 +78,6 @@ params:
...
@@ -82,7 +78,6 @@ params:
hidden_channels
:
8
hidden_channels
:
8
hidden_heads
:
8
hidden_heads
:
8
out_heads
:
1
out_heads
:
1
residual
:
false
dropout
:
0.6
dropout
:
0.6
num_parts
:
4
num_parts
:
4
batch_size
:
1
batch_size
:
1
...
@@ -100,7 +95,6 @@ params:
...
@@ -100,7 +95,6 @@ params:
hidden_channels
:
14
hidden_channels
:
14
hidden_heads
:
5
hidden_heads
:
5
out_heads
:
1
out_heads
:
1
residual
:
false
dropout
:
0.5
dropout
:
0.5
num_parts
:
2
num_parts
:
2
batch_size
:
1
batch_size
:
1
...
...
torch_geometric_autoscale/models/appnp.py
View file @
38ca4fb2
...
@@ -14,8 +14,8 @@ class APPNP(ScalableGNN):
...
@@ -14,8 +14,8 @@ class APPNP(ScalableGNN):
out_channels
:
int
,
num_layers
:
int
,
alpha
:
float
,
out_channels
:
int
,
num_layers
:
int
,
alpha
:
float
,
dropout
:
float
=
0.0
,
pool_size
:
Optional
[
int
]
=
None
,
dropout
:
float
=
0.0
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
APPNP
,
self
).
__init__
(
num_nodes
,
out_channels
,
num_layers
,
super
().
__init__
(
num_nodes
,
out_channels
,
num_layers
,
pool_size
,
pool_size
,
buffer_size
,
device
)
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -30,15 +30,11 @@ class APPNP(ScalableGNN):
...
@@ -30,15 +30,11 @@ class APPNP(ScalableGNN):
self
.
nonreg_modules
=
self
.
lins
[
1
:]
self
.
nonreg_modules
=
self
.
lins
[
1
:]
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
super
(
APPNP
,
self
).
reset_parameters
()
super
().
reset_parameters
()
for
lin
in
self
.
lins
:
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
lin
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
*
args
)
->
Tensor
:
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
0
](
x
)
x
=
self
.
lins
[
0
](
x
)
x
=
x
.
relu
()
x
=
x
.
relu
()
...
@@ -48,7 +44,7 @@ class APPNP(ScalableGNN):
...
@@ -48,7 +44,7 @@ class APPNP(ScalableGNN):
for
history
in
self
.
histories
:
for
history
in
self
.
histories
:
x
=
(
1
-
self
.
alpha
)
*
(
adj_t
@
x
)
+
self
.
alpha
*
x_0
x
=
(
1
-
self
.
alpha
)
*
(
adj_t
@
x
)
+
self
.
alpha
*
x_0
x
=
self
.
push_and_pull
(
history
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
self
.
push_and_pull
(
history
,
x
,
*
args
)
x
=
(
1
-
self
.
alpha
)
*
(
adj_t
@
x
)
+
self
.
alpha
*
x_0
x
=
(
1
-
self
.
alpha
)
*
(
adj_t
@
x
)
+
self
.
alpha
*
x_0
return
x
return
x
...
...
torch_geometric_autoscale/models/base.py
View file @
38ca4fb2
...
@@ -14,7 +14,7 @@ class ScalableGNN(torch.nn.Module):
...
@@ -14,7 +14,7 @@ class ScalableGNN(torch.nn.Module):
def
__init__
(
self
,
num_nodes
:
int
,
hidden_channels
:
int
,
num_layers
:
int
,
def
__init__
(
self
,
num_nodes
:
int
,
hidden_channels
:
int
,
num_layers
:
int
,
pool_size
:
Optional
[
int
]
=
None
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
ScalableGNN
,
self
).
__init__
()
super
().
__init__
()
self
.
num_nodes
=
num_nodes
self
.
num_nodes
=
num_nodes
self
.
hidden_channels
=
hidden_channels
self
.
hidden_channels
=
hidden_channels
...
@@ -40,7 +40,7 @@ class ScalableGNN(torch.nn.Module):
...
@@ -40,7 +40,7 @@ class ScalableGNN(torch.nn.Module):
return
self
.
histories
[
0
].
_device
return
self
.
histories
[
0
].
_device
def
_apply
(
self
,
fn
:
Callable
)
->
None
:
def
_apply
(
self
,
fn
:
Callable
)
->
None
:
super
(
ScalableGNN
,
self
).
_apply
(
fn
)
super
().
_apply
(
fn
)
# We only initialize the AsyncIOPool in case histories are on CPU:
# We only initialize the AsyncIOPool in case histories are on CPU:
if
(
str
(
self
.
emb_device
)
==
'cpu'
and
str
(
self
.
device
)[:
4
]
==
'cuda'
if
(
str
(
self
.
emb_device
)
==
'cpu'
and
str
(
self
.
device
)[:
4
]
==
'cuda'
and
self
.
pool_size
is
not
None
and
self
.
pool_size
is
not
None
...
...
torch_geometric_autoscale/models/gat.py
View file @
38ca4fb2
...
@@ -3,7 +3,7 @@ from typing import Optional
...
@@ -3,7 +3,7 @@ from typing import Optional
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn
import
Linear
,
ModuleList
from
torch.nn
import
ModuleList
from
torch_sparse
import
SparseTensor
from
torch_sparse
import
SparseTensor
from
torch_geometric.nn
import
GATConv
from
torch_geometric.nn
import
GATConv
...
@@ -13,17 +13,16 @@ from torch_geometric_autoscale.models import ScalableGNN
...
@@ -13,17 +13,16 @@ from torch_geometric_autoscale.models import ScalableGNN
class
GAT
(
ScalableGNN
):
class
GAT
(
ScalableGNN
):
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
,
hidden_channels
:
int
,
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
,
hidden_channels
:
int
,
hidden_heads
:
int
,
out_channels
:
int
,
out_heads
:
int
,
hidden_heads
:
int
,
out_channels
:
int
,
out_heads
:
int
,
num_layers
:
int
,
residual
:
bool
=
False
,
dropout
:
float
=
0.0
,
num_layers
:
int
,
dropout
:
float
=
0.0
,
pool_size
:
Optional
[
int
]
=
None
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
GAT
,
self
).
__init__
(
num_nodes
,
hidden_channels
*
hidden_heads
,
super
().
__init__
(
num_nodes
,
hidden_channels
*
hidden_heads
,
num_layers
,
num_layers
,
pool_size
,
buffer_size
,
device
)
pool_size
,
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
hidden_heads
=
hidden_heads
self
.
hidden_heads
=
hidden_heads
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
self
.
out_heads
=
out_heads
self
.
out_heads
=
out_heads
self
.
residual
=
residual
self
.
dropout
=
dropout
self
.
dropout
=
dropout
self
.
convs
=
ModuleList
()
self
.
convs
=
ModuleList
()
...
@@ -37,62 +36,33 @@ class GAT(ScalableGNN):
...
@@ -37,62 +36,33 @@ class GAT(ScalableGNN):
concat
=
False
,
dropout
=
dropout
,
add_self_loops
=
False
)
concat
=
False
,
dropout
=
dropout
,
add_self_loops
=
False
)
self
.
convs
.
append
(
conv
)
self
.
convs
.
append
(
conv
)
self
.
lins
=
ModuleList
()
self
.
reg_modules
=
self
.
convs
if
residual
:
self
.
lins
.
append
(
Linear
(
in_channels
,
hidden_channels
*
hidden_heads
))
self
.
lins
.
append
(
Linear
(
hidden_channels
*
hidden_heads
,
out_channels
))
self
.
reg_modules
=
ModuleList
([
self
.
convs
,
self
.
lins
])
self
.
nonreg_modules
=
ModuleList
()
self
.
nonreg_modules
=
ModuleList
()
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
super
(
GAT
,
self
).
reset_parameters
()
super
().
reset_parameters
()
for
conv
in
self
.
convs
:
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
conv
.
reset_parameters
()
for
lin
in
self
.
lins
:
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
lin
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
*
args
)
->
Tensor
:
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
for
conv
,
history
in
zip
(
self
.
convs
[:
-
1
],
self
.
histories
):
for
conv
,
history
in
zip
(
self
.
convs
[:
-
1
],
self
.
histories
):
h
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
conv
((
h
,
h
[:
adj_t
.
size
(
0
)]),
adj_t
)
if
self
.
residual
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
+=
x
if
h
.
size
(
-
1
)
==
x
.
size
(
-
1
)
else
self
.
lins
[
0
](
x
)
x
=
F
.
elu
(
h
)
x
=
self
.
push_and_pull
(
history
,
x
,
batch_size
,
n_id
,
offset
,
count
)
h
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
-
1
]((
h
,
h
[:
adj_t
.
size
(
0
)]),
adj_t
)
if
self
.
residual
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
+=
self
.
lins
[
1
](
x
)
x
=
conv
((
x
,
x
[:
adj_t
.
size
(
0
)]),
adj_t
)
return
h
x
=
F
.
elu
(
x
)
x
=
self
.
push_and_pull
(
history
,
x
,
*
args
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
convs
[
-
1
]((
x
,
x
[:
adj_t
.
size
(
0
)]),
adj_t
)
return
x
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
h
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
layer
]((
h
,
h
[:
adj_t
.
size
(
0
)]),
adj_t
)
x
=
self
.
convs
[
layer
]((
x
,
x
[:
adj_t
.
size
(
0
)]),
adj_t
)
if
layer
==
0
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
0
](
x
)
if
layer
==
self
.
num_layers
-
1
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
1
](
x
)
if
self
.
residual
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
+=
x
if
layer
<
self
.
num_layers
-
1
:
if
layer
<
self
.
num_layers
-
1
:
h
=
h
.
elu
()
x
=
x
.
elu
()
return
h
return
x
torch_geometric_autoscale/models/gcn.py
View file @
38ca4fb2
...
@@ -17,8 +17,8 @@ class GCN(ScalableGNN):
...
@@ -17,8 +17,8 @@ class GCN(ScalableGNN):
residual
:
bool
=
False
,
linear
:
bool
=
False
,
residual
:
bool
=
False
,
linear
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
GCN
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
super
().
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
pool_size
,
buffer_size
,
device
)
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -60,7 +60,7 @@ class GCN(ScalableGNN):
...
@@ -60,7 +60,7 @@ class GCN(ScalableGNN):
return
self
.
lins
if
self
.
linear
else
self
.
convs
[
-
1
:]
return
self
.
lins
if
self
.
linear
else
self
.
convs
[
-
1
:]
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
super
(
GCN
,
self
).
reset_parameters
()
super
().
reset_parameters
()
for
lin
in
self
.
lins
:
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
lin
.
reset_parameters
()
for
conv
in
self
.
convs
:
for
conv
in
self
.
convs
:
...
@@ -68,11 +68,7 @@ class GCN(ScalableGNN):
...
@@ -68,11 +68,7 @@ class GCN(ScalableGNN):
for
bn
in
self
.
bns
:
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
*
args
)
->
Tensor
:
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
...
@@ -87,7 +83,7 @@ class GCN(ScalableGNN):
...
@@ -87,7 +83,7 @@ class GCN(ScalableGNN):
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
h
.
relu_
()
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
self
.
push_and_pull
(
hist
,
x
,
*
args
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
-
1
](
x
,
adj_t
)
h
=
self
.
convs
[
-
1
](
x
,
adj_t
)
...
...
torch_geometric_autoscale/models/gcn2.py
View file @
38ca4fb2
...
@@ -18,8 +18,8 @@ class GCN2(ScalableGNN):
...
@@ -18,8 +18,8 @@ class GCN2(ScalableGNN):
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
GCN2
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
super
().
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
pool_size
,
buffer_size
,
device
)
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -53,7 +53,7 @@ class GCN2(ScalableGNN):
...
@@ -53,7 +53,7 @@ class GCN2(ScalableGNN):
return
self
.
lins
return
self
.
lins
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
super
(
GCN2
,
self
).
reset_parameters
()
super
().
reset_parameters
()
for
lin
in
self
.
lins
:
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
lin
.
reset_parameters
()
for
conv
in
self
.
convs
:
for
conv
in
self
.
convs
:
...
@@ -61,11 +61,7 @@ class GCN2(ScalableGNN):
...
@@ -61,11 +61,7 @@ class GCN2(ScalableGNN):
for
bn
in
self
.
bns
:
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
*
args
)
->
Tensor
:
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
...
@@ -80,7 +76,7 @@ class GCN2(ScalableGNN):
...
@@ -80,7 +76,7 @@ class GCN2(ScalableGNN):
if
self
.
residual
:
if
self
.
residual
:
h
+=
x
[:
h
.
size
(
0
)]
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
h
.
relu_
()
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
self
.
push_and_pull
(
hist
,
x
,
*
args
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
-
1
](
x
,
x_0
,
adj_t
)
h
=
self
.
convs
[
-
1
](
x
,
x_0
,
adj_t
)
...
...
torch_geometric_autoscale/models/pna.py
View file @
38ca4fb2
...
@@ -17,8 +17,7 @@ class PNAConv(MessagePassing):
...
@@ -17,8 +17,7 @@ class PNAConv(MessagePassing):
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
aggregators
:
List
[
str
],
scalers
:
List
[
str
],
deg
:
Tensor
,
aggregators
:
List
[
str
],
scalers
:
List
[
str
],
deg
:
Tensor
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
aggr
=
None
,
**
kwargs
)
super
(
PNAConv
,
self
).
__init__
(
aggr
=
None
,
**
kwargs
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -83,8 +82,8 @@ class PNA(ScalableGNN):
...
@@ -83,8 +82,8 @@ class PNA(ScalableGNN):
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
PNA
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
super
().
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
pool_size
,
buffer_size
,
device
)
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -115,17 +114,13 @@ class PNA(ScalableGNN):
...
@@ -115,17 +114,13 @@ class PNA(ScalableGNN):
return
self
.
convs
[
-
1
:]
return
self
.
convs
[
-
1
:]
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
super
(
PNA
,
self
).
reset_parameters
()
super
().
reset_parameters
()
for
conv
in
self
.
convs
:
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
conv
.
reset_parameters
()
for
bn
in
self
.
bns
:
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
*
args
)
->
Tensor
:
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
...
@@ -136,7 +131,7 @@ class PNA(ScalableGNN):
...
@@ -136,7 +131,7 @@ class PNA(ScalableGNN):
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
h
.
relu_
()
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
self
.
push_and_pull
(
hist
,
x
,
*
args
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
convs
[
-
1
](
x
,
adj_t
)
x
=
self
.
convs
[
-
1
](
x
,
adj_t
)
...
...
torch_geometric_autoscale/models/pna_jk.py
View file @
38ca4fb2
...
@@ -18,8 +18,8 @@ class PNA_JK(ScalableGNN):
...
@@ -18,8 +18,8 @@ class PNA_JK(ScalableGNN):
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
PNA_JK
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
super
().
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
pool_size
,
buffer_size
,
device
)
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
out_channels
=
out_channels
...
@@ -59,7 +59,7 @@ class PNA_JK(ScalableGNN):
...
@@ -59,7 +59,7 @@ class PNA_JK(ScalableGNN):
return
self
.
lins
return
self
.
lins
def
reset_parameters
(
self
):
def
reset_parameters
(
self
):
super
(
PNA_JK
,
self
).
reset_parameters
()
super
().
reset_parameters
()
for
lin
in
self
.
lins
:
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
lin
.
reset_parameters
()
for
conv
in
self
.
convs
:
for
conv
in
self
.
convs
:
...
@@ -67,11 +67,7 @@ class PNA_JK(ScalableGNN):
...
@@ -67,11 +67,7 @@ class PNA_JK(ScalableGNN):
for
bn
in
self
.
bns
:
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
*
args
)
->
Tensor
:
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
...
@@ -87,7 +83,7 @@ class PNA_JK(ScalableGNN):
...
@@ -87,7 +83,7 @@ class PNA_JK(ScalableGNN):
h
+=
x
[:
h
.
size
(
0
)]
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
h
.
relu_
()
xs
+=
[
x
]
xs
+=
[
x
]
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
self
.
push_and_pull
(
hist
,
x
,
*
args
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
-
1
](
x
,
adj_t
)
h
=
self
.
convs
[
-
1
](
x
,
adj_t
)
...
@@ -104,6 +100,8 @@ class PNA_JK(ScalableGNN):
...
@@ -104,6 +100,8 @@ class PNA_JK(ScalableGNN):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
# We keep the skip connections in GPU memory for now. If one encounters
# GPU memory problems, it is advised to push `state['xs']` to the CPU.
if
layer
==
0
:
if
layer
==
0
:
if
self
.
drop_input
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
...
...
torch_geometric_autoscale/pool.py
View file @
38ca4fb2
...
@@ -11,7 +11,7 @@ write_async = torch.ops.torch_geometric_autoscale.write_async
...
@@ -11,7 +11,7 @@ write_async = torch.ops.torch_geometric_autoscale.write_async
class
AsyncIOPool
(
torch
.
nn
.
Module
):
class
AsyncIOPool
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
pool_size
:
int
,
buffer_size
:
int
,
embedding_dim
:
int
):
def
__init__
(
self
,
pool_size
:
int
,
buffer_size
:
int
,
embedding_dim
:
int
):
super
(
AsyncIOPool
,
self
).
__init__
()
super
().
__init__
()
self
.
pool_size
=
pool_size
self
.
pool_size
=
pool_size
self
.
buffer_size
=
buffer_size
self
.
buffer_size
=
buffer_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment