Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
2f25da6c
Commit
2f25da6c
authored
Feb 02, 2021
by
rusty1s
Browse files
initial commit
parent
ac165af3
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
751 additions
and
0 deletions
+751
-0
torch_geometric_autoscale/models/gcn.py
torch_geometric_autoscale/models/gcn.py
+92
-0
torch_geometric_autoscale/models/gcn2.py
torch_geometric_autoscale/models/gcn2.py
+117
-0
torch_geometric_autoscale/models/gin.py
torch_geometric_autoscale/models/gin.py
+83
-0
torch_geometric_autoscale/models/pna.py
torch_geometric_autoscale/models/pna.py
+159
-0
torch_geometric_autoscale/models/pna_jk.py
torch_geometric_autoscale/models/pna_jk.py
+128
-0
torch_geometric_autoscale/pool.py
torch_geometric_autoscale/pool.py
+121
-0
torch_geometric_autoscale/utils.py
torch_geometric_autoscale/utils.py
+51
-0
No files found.
torch_geometric_autoscale/models/gcn.py
0 → 100644
View file @
2f25da6c
from
typing
import
Optional
import
torch
from
torch
import
Tensor
import
torch.nn.functional
as
F
from
torch.nn
import
ModuleList
,
BatchNorm1d
from
torch_sparse
import
SparseTensor
from
torch_geometric.nn
import
GCNConv
from
scaling_gnns.models.base2
import
ScalableGNN
class
GCN
(
ScalableGNN
):
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
,
hidden_channels
:
int
,
out_channels
:
int
,
num_layers
:
int
,
dropout
:
float
=
0.0
,
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
GCN
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
dropout
=
dropout
self
.
drop_input
=
drop_input
self
.
batch_norm
=
batch_norm
self
.
residual
=
residual
self
.
convs
=
ModuleList
()
for
i
in
range
(
num_layers
):
in_dim
=
in_channels
if
i
==
0
else
hidden_channels
out_dim
=
out_channels
if
i
==
num_layers
-
1
else
hidden_channels
conv
=
GCNConv
(
in_dim
,
out_dim
,
normalize
=
False
)
self
.
convs
.
append
(
conv
)
self
.
bns
=
ModuleList
()
for
i
in
range
(
num_layers
-
1
):
bn
=
BatchNorm1d
(
hidden_channels
)
self
.
bns
.
append
(
bn
)
@
property
def
reg_modules
(
self
):
return
ModuleList
(
list
(
self
.
convs
[:
-
1
])
+
list
(
self
.
bns
))
@
property
def
nonreg_modules
(
self
):
return
self
.
convs
[
-
1
:]
def
reset_parameters
(
self
):
super
(
GCN
,
self
).
reset_parameters
()
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
for
conv
,
bn
,
hist
in
zip
(
self
.
convs
[:
-
1
],
self
.
bns
,
self
.
histories
):
h
=
conv
(
x
,
adj_t
)
if
self
.
batch_norm
:
h
=
bn
(
h
)
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
convs
[
-
1
](
x
,
adj_t
)
return
x
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
if
layer
==
0
and
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
else
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
layer
](
x
,
adj_t
)
if
layer
<
self
.
num_layers
-
1
:
if
self
.
batch_norm
:
h
=
self
.
bns
[
layer
](
h
)
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
h
=
h
.
relu_
()
return
h
torch_geometric_autoscale/models/gcn2.py
0 → 100644
View file @
2f25da6c
from
typing
import
Optional
import
torch
from
torch
import
Tensor
import
torch.nn.functional
as
F
from
torch.nn
import
ModuleList
,
Linear
,
BatchNorm1d
from
torch_sparse
import
SparseTensor
from
torch_geometric.nn
import
GCN2Conv
from
scaling_gnns.models.base2
import
ScalableGNN
class
GCN2
(
ScalableGNN
):
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
,
hidden_channels
:
int
,
out_channels
:
int
,
num_layers
:
int
,
alpha
:
float
,
theta
:
float
,
shared_weights
:
bool
=
True
,
dropout
:
float
=
0.0
,
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
GCN2
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
dropout
=
dropout
self
.
drop_input
=
drop_input
self
.
batch_norm
=
batch_norm
self
.
residual
=
residual
self
.
lins
=
ModuleList
()
self
.
lins
.
append
(
Linear
(
in_channels
,
hidden_channels
))
self
.
lins
.
append
(
Linear
(
hidden_channels
,
out_channels
))
self
.
convs
=
ModuleList
()
for
i
in
range
(
num_layers
):
conv
=
GCN2Conv
(
hidden_channels
,
alpha
=
alpha
,
theta
=
theta
,
layer
=
i
+
1
,
shared_weights
=
shared_weights
,
normalize
=
False
)
self
.
convs
.
append
(
conv
)
self
.
bns
=
ModuleList
()
for
i
in
range
(
num_layers
):
bn
=
BatchNorm1d
(
hidden_channels
)
self
.
bns
.
append
(
bn
)
@
property
def
reg_modules
(
self
):
return
ModuleList
(
list
(
self
.
convs
)
+
list
(
self
.
bns
))
@
property
def
nonreg_modules
(
self
):
return
self
.
lins
def
reset_parameters
(
self
):
super
(
GCN2
,
self
).
reset_parameters
()
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
x_0
=
self
.
lins
[
0
](
x
).
relu_
()
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
for
conv
,
bn
,
hist
in
zip
(
self
.
convs
[:
-
1
],
self
.
bns
[:
-
1
],
self
.
histories
):
h
=
conv
(
x
,
x_0
,
adj_t
)
if
self
.
batch_norm
:
h
=
bn
(
h
)
if
self
.
residual
:
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
-
1
](
x
,
x_0
,
adj_t
)
if
self
.
batch_norm
:
h
=
self
.
bns
[
-
1
](
h
)
if
self
.
residual
:
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
1
](
x
)
return
x
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
if
layer
==
0
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
x_0
=
self
.
lins
[
0
](
x
).
relu_
()
state
[
'x_0'
]
=
x_0
[:
adj_t
.
size
(
0
)]
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
layer
](
x
,
state
[
'x_0'
],
adj_t
)
if
self
.
batch_norm
:
h
=
self
.
bns
[
layer
](
h
)
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
if
layer
==
self
.
num_layers
-
1
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
1
](
x
)
return
x
torch_geometric_autoscale/models/gin.py
0 → 100644
View file @
2f25da6c
from
typing
import
Optional
import
torch
from
torch
import
Tensor
import
torch.nn.functional
as
F
from
torch.nn
import
ModuleList
,
Identity
from
torch.nn
import
Sequential
,
Linear
,
BatchNorm1d
,
ReLU
from
torch_sparse
import
SparseTensor
from
torch_geometric.nn
import
GINConv
from
torch_geometric.nn.inits
import
reset
from
.base
import
HistoryGNN
class
GIN
(
HistoryGNN
):
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
,
hidden_channels
:
int
,
out_channels
:
int
,
num_layers
:
int
,
residual
:
bool
=
False
,
dropout
:
float
=
0.0
,
device
=
None
,
dtype
=
None
):
super
(
GIN
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
device
,
dtype
)
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
residual
=
residual
self
.
dropout
=
dropout
self
.
lins
=
ModuleList
()
self
.
lins
.
append
(
Linear
(
in_channels
,
hidden_channels
))
self
.
lins
.
append
(
Linear
(
hidden_channels
,
out_channels
))
self
.
convs
=
ModuleList
()
for
_
in
range
(
num_layers
):
conv
=
GINConv
(
nn
=
Identity
(),
train_eps
=
True
)
self
.
convs
.
append
(
conv
)
self
.
post_nns
=
ModuleList
()
for
i
in
range
(
num_layers
):
post_nn
=
Sequential
(
Linear
(
hidden_channels
,
hidden_channels
),
BatchNorm1d
(
hidden_channels
,
track_running_stats
=
False
),
ReLU
(
inplace
=
True
),
Linear
(
hidden_channels
,
hidden_channels
),
ReLU
(
inplace
=
True
),
)
self
.
post_nns
.
append
(
post_nn
)
def
reset_parameters
(
self
):
super
(
GIN
,
self
).
reset_parameters
()
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
for
post_nn
in
self
.
post_nns
:
reset
(
post_nn
)
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
x
=
self
.
lins
[
0
](
x
).
relu
()
for
conv
,
post_nn
,
history
in
zip
(
self
.
convs
[:
-
1
],
self
.
post_nns
[:
-
1
],
self
.
histories
):
if
batch_size
is
not
None
:
h
=
torch
.
zeros_like
(
x
)
h
[:
batch_size
]
=
post_nn
(
conv
(
x
,
adj_t
)[:
batch_size
])
else
:
h
=
post_nn
(
conv
(
x
,
adj_t
))
x
=
h
.
add_
(
x
)
if
self
.
residual
else
h
x
=
self
.
push_and_pull
(
history
,
x
,
batch_size
,
n_id
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
if
batch_size
is
not
None
:
h
=
self
.
post_nns
[
-
1
](
self
.
convs
[
-
1
](
x
,
adj_t
)[:
batch_size
])
x
=
x
[:
batch_size
]
else
:
h
=
self
.
post_nns
[
-
1
](
self
.
convs
[
-
1
](
x
,
adj_t
))
x
=
h
.
add_
(
x
)
if
self
.
residual
else
h
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
1
](
x
)
return
x
torch_geometric_autoscale/models/pna.py
0 → 100644
View file @
2f25da6c
from
itertools
import
product
from
typing
import
Optional
,
List
import
torch
from
torch
import
Tensor
import
torch.nn.functional
as
F
from
torch.nn
import
ModuleList
,
Linear
,
BatchNorm1d
from
torch_sparse
import
SparseTensor
from
torch_geometric.nn
import
MessagePassing
from
scaling_gnns.models.base2
import
ScalableGNN
EPS
=
1e-5
class
PNAConv
(
MessagePassing
):
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
,
aggregators
:
List
[
str
],
scalers
:
List
[
str
],
deg
:
Tensor
,
**
kwargs
):
super
(
PNAConv
,
self
).
__init__
(
aggr
=
None
,
**
kwargs
)
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
aggregators
=
aggregators
self
.
scalers
=
scalers
deg
=
deg
.
to
(
torch
.
float
)
self
.
avg_deg
=
{
'lin'
:
deg
.
mean
().
item
(),
'log'
:
(
deg
+
1
).
log
().
mean
().
item
(),
}
self
.
pre_lins
=
torch
.
nn
.
ModuleList
([
Linear
(
in_channels
,
out_channels
)
for
_
in
range
(
len
(
aggregators
)
*
len
(
scalers
))
])
self
.
post_lins
=
torch
.
nn
.
ModuleList
([
Linear
(
out_channels
,
out_channels
)
for
_
in
range
(
len
(
aggregators
)
*
len
(
scalers
))
])
self
.
lin
=
Linear
(
in_channels
,
out_channels
)
self
.
reset_parameters
()
def
reset_parameters
(
self
):
for
lin
in
self
.
pre_lins
:
lin
.
reset_parameters
()
for
lin
in
self
.
post_lins
:
lin
.
reset_parameters
()
self
.
lin
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
):
out
=
self
.
propagate
(
adj_t
,
x
=
x
)
out
+=
self
.
lin
(
x
)[:
out
.
size
(
0
)]
return
out
def
message_and_aggregate
(
self
,
adj_t
:
SparseTensor
,
x
:
Tensor
)
->
Tensor
:
deg
=
adj_t
.
storage
.
rowcount
().
to
(
x
.
dtype
).
view
(
-
1
,
1
)
out
=
0
for
(
aggr
,
scaler
),
pre_lin
,
post_lin
in
zip
(
product
(
self
.
aggregators
,
self
.
scalers
),
self
.
pre_lins
,
self
.
post_lins
):
h
=
pre_lin
(
x
).
relu_
()
h
=
adj_t
.
matmul
(
h
,
reduce
=
aggr
)
h
=
post_lin
(
h
)
if
scaler
==
'amplification'
:
h
*=
(
deg
+
1
).
log
()
/
self
.
avg_deg
[
'log'
]
elif
scaler
==
'attenuation'
:
h
*=
self
.
avg_deg
[
'log'
]
/
((
deg
+
1
).
log
()
+
EPS
)
out
+=
h
return
out
class
PNA
(
ScalableGNN
):
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
:
int
,
hidden_channels
:
int
,
out_channels
:
int
,
num_layers
:
int
,
aggregators
:
List
[
int
],
scalers
:
List
[
int
],
deg
:
Tensor
,
dropout
:
float
=
0.0
,
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
PNA
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
dropout
=
dropout
self
.
drop_input
=
drop_input
self
.
batch_norm
=
batch_norm
self
.
residual
=
residual
self
.
convs
=
ModuleList
()
for
i
in
range
(
num_layers
):
in_dim
=
in_channels
if
i
==
0
else
hidden_channels
out_dim
=
out_channels
if
i
==
num_layers
-
1
else
hidden_channels
conv
=
PNAConv
(
in_dim
,
out_dim
,
aggregators
=
aggregators
,
scalers
=
scalers
,
deg
=
deg
)
self
.
convs
.
append
(
conv
)
self
.
bns
=
ModuleList
()
for
i
in
range
(
num_layers
-
1
):
bn
=
BatchNorm1d
(
hidden_channels
)
self
.
bns
.
append
(
bn
)
@
property
def
reg_modules
(
self
):
return
ModuleList
(
list
(
self
.
convs
[:
-
1
])
+
list
(
self
.
bns
))
@
property
def
nonreg_modules
(
self
):
return
self
.
convs
[
-
1
:]
def
reset_parameters
(
self
):
super
(
PNA
,
self
).
reset_parameters
()
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
for
conv
,
bn
,
hist
in
zip
(
self
.
convs
[:
-
1
],
self
.
bns
,
self
.
histories
):
h
=
conv
(
x
,
adj_t
)
if
self
.
batch_norm
:
h
=
bn
(
h
)
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
convs
[
-
1
](
x
,
adj_t
)
return
x
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
if
layer
==
0
and
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
layer
](
x
,
adj_t
)
if
layer
<
self
.
num_layers
-
1
:
if
self
.
batch_norm
:
h
=
self
.
bns
[
layer
](
h
)
if
self
.
residual
and
h
.
size
(
-
1
)
==
x
.
size
(
-
1
):
h
+=
x
[:
h
.
size
(
0
)]
h
=
h
.
relu_
()
h
=
F
.
dropout
(
h
,
p
=
self
.
dropout
,
training
=
self
.
training
)
return
h
torch_geometric_autoscale/models/pna_jk.py
0 → 100644
View file @
2f25da6c
from
typing
import
Optional
,
List
import
torch
from
torch
import
Tensor
import
torch.nn.functional
as
F
from
torch.nn
import
(
ModuleList
,
Linear
,
BatchNorm1d
,
Sequential
,
ReLU
,
Identity
)
from
torch_sparse
import
SparseTensor
from
scaling_gnns.models.base2
import
ScalableGNN
from
scaling_gnns.models.pna
import
PNAConv
class
PNA_JK
(
ScalableGNN
):
def
__init__
(
self
,
num_nodes
:
int
,
in_channels
:
int
,
hidden_channels
:
int
,
out_channels
:
int
,
num_layers
:
int
,
aggregators
:
List
[
int
],
scalers
:
List
[
int
],
deg
:
Tensor
,
dropout
:
float
=
0.0
,
drop_input
:
bool
=
True
,
batch_norm
:
bool
=
False
,
residual
:
bool
=
False
,
pool_size
:
Optional
[
int
]
=
None
,
buffer_size
:
Optional
[
int
]
=
None
,
device
=
None
):
super
(
PNA_JK
,
self
).
__init__
(
num_nodes
,
hidden_channels
,
num_layers
,
pool_size
,
buffer_size
,
device
)
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
num_layers
==
num_layers
self
.
dropout
=
dropout
self
.
drop_input
=
drop_input
self
.
batch_norm
=
batch_norm
self
.
residual
=
residual
self
.
lins
=
ModuleList
()
self
.
lins
.
append
(
Sequential
(
Linear
(
in_channels
,
hidden_channels
),
BatchNorm1d
(
hidden_channels
)
if
batch_norm
else
Identity
(),
ReLU
(
inplace
=
True
),
))
self
.
lins
.
append
(
Linear
((
num_layers
+
1
)
*
hidden_channels
,
out_channels
))
self
.
convs
=
ModuleList
()
for
_
in
range
(
num_layers
):
conv
=
PNAConv
(
hidden_channels
,
hidden_channels
,
aggregators
=
aggregators
,
scalers
=
scalers
,
deg
=
deg
)
self
.
convs
.
append
(
conv
)
self
.
bns
=
ModuleList
()
for
_
in
range
(
num_layers
):
bn
=
BatchNorm1d
(
hidden_channels
)
self
.
bns
.
append
(
bn
)
@
property
def
reg_modules
(
self
):
return
ModuleList
(
list
(
self
.
convs
)
+
list
(
self
.
bns
))
@
property
def
nonreg_modules
(
self
):
return
self
.
lins
def
reset_parameters
(
self
):
super
(
PNA_JK
,
self
).
reset_parameters
()
for
lin
in
self
.
lins
:
lin
.
reset_parameters
()
for
conv
in
self
.
convs
:
conv
.
reset_parameters
()
for
bn
in
self
.
bns
:
bn
.
reset_parameters
()
def
forward
(
self
,
x
:
Tensor
,
adj_t
:
SparseTensor
,
batch_size
:
Optional
[
int
]
=
None
,
n_id
:
Optional
[
Tensor
]
=
None
,
offset
:
Optional
[
Tensor
]
=
None
,
count
:
Optional
[
Tensor
]
=
None
)
->
Tensor
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
0
](
x
)
xs
=
[
x
[:
adj_t
.
size
(
0
)]]
for
conv
,
bn
,
hist
in
zip
(
self
.
convs
[:
-
1
],
self
.
bns
[:
-
1
],
self
.
histories
):
h
=
conv
(
x
,
adj_t
)
if
self
.
batch_norm
:
h
=
bn
(
h
)
if
self
.
residual
:
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
xs
+=
[
x
]
x
=
self
.
push_and_pull
(
hist
,
x
,
batch_size
,
n_id
,
offset
,
count
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
convs
[
-
1
](
x
,
adj_t
)
if
self
.
batch_norm
:
h
=
self
.
bns
[
-
1
](
h
)
if
self
.
residual
:
h
+=
x
[:
h
.
size
(
0
)]
x
=
h
.
relu_
()
xs
+=
[
x
]
x
=
torch
.
cat
(
xs
,
dim
=-
1
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
return
self
.
lins
[
1
](
x
)
@
torch
.
no_grad
()
def
forward_layer
(
self
,
layer
,
x
,
adj_t
,
state
):
if
layer
==
0
:
if
self
.
drop_input
:
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
self
.
lins
[
0
](
x
)
state
[
'xs'
]
=
[
x
[:
adj_t
.
size
(
0
)]]
h
=
self
.
convs
[
layer
](
x
,
adj_t
)
if
self
.
batch_norm
:
h
=
self
.
bns
[
layer
](
h
)
if
self
.
residual
:
h
+=
x
[:
h
.
size
(
0
)]
h
=
h
.
relu_
()
state
[
'xs'
]
+=
[
h
]
h
=
F
.
dropout
(
h
,
p
=
self
.
dropout
,
training
=
self
.
training
)
if
layer
==
self
.
num_layers
-
1
:
h
=
torch
.
cat
(
state
[
'xs'
],
dim
=-
1
)
h
=
F
.
dropout
(
h
,
p
=
self
.
dropout
,
training
=
self
.
training
)
h
=
self
.
lins
[
1
](
h
)
return
h
torch_geometric_autoscale/pool.py
0 → 100644
View file @
2f25da6c
from
typing
import
Optional
,
Callable
import
torch
from
torch
import
Tensor
from
torch.cuda
import
Stream
synchronize
=
torch
.
ops
.
scaling_gnns
.
synchronize
read_async
=
torch
.
ops
.
scaling_gnns
.
read_async
write_async
=
torch
.
ops
.
scaling_gnns
.
write_async
class
AsyncIOPool
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
pool_size
:
int
,
buffer_size
:
int
,
embedding_dim
:
int
):
super
(
AsyncIOPool
,
self
).
__init__
()
self
.
pool_size
=
pool_size
self
.
embedding_dim
=
embedding_dim
self
.
buffer_size
=
buffer_size
self
.
_device
=
torch
.
device
(
'cpu'
)
self
.
_pull_queue
=
[]
self
.
_push_cache
=
[
None
]
*
pool_size
self
.
_push_streams
=
[
None
]
*
pool_size
self
.
_pull_streams
=
[
None
]
*
pool_size
self
.
_cpu_buffers
=
[
None
]
*
pool_size
self
.
_cuda_buffers
=
[
None
]
*
pool_size
self
.
_pull_index
=
-
1
self
.
_push_index
=
-
1
def
_apply
(
self
,
fn
:
Callable
)
->
None
:
self
.
_device
=
fn
(
torch
.
zeros
(
1
)).
device
return
self
def
_pull_stream
(
self
,
idx
:
int
)
->
Stream
:
if
self
.
_pull_streams
[
idx
]
is
None
:
assert
str
(
self
.
_device
)[:
4
]
==
'cuda'
self
.
_pull_streams
[
idx
]
=
torch
.
cuda
.
Stream
(
self
.
_device
)
return
self
.
_pull_streams
[
idx
]
def
_push_stream
(
self
,
idx
:
int
)
->
Stream
:
if
self
.
_push_streams
[
idx
]
is
None
:
assert
str
(
self
.
_device
)[:
4
]
==
'cuda'
self
.
_push_streams
[
idx
]
=
torch
.
cuda
.
Stream
(
self
.
_device
)
return
self
.
_push_streams
[
idx
]
def
_cpu_buffer
(
self
,
idx
:
int
)
->
Tensor
:
if
self
.
_cpu_buffers
[
idx
]
is
None
:
self
.
_cpu_buffers
[
idx
]
=
torch
.
empty
(
self
.
buffer_size
,
self
.
embedding_dim
,
pin_memory
=
True
)
return
self
.
_cpu_buffers
[
idx
]
def
_cuda_buffer
(
self
,
idx
:
int
)
->
Tensor
:
if
self
.
_cuda_buffers
[
idx
]
is
None
:
assert
str
(
self
.
_device
)[:
4
]
==
'cuda'
self
.
_cuda_buffers
[
idx
]
=
torch
.
empty
(
self
.
buffer_size
,
self
.
embedding_dim
,
device
=
self
.
_device
)
return
self
.
_cuda_buffers
[
idx
]
@
torch
.
no_grad
()
def
async_pull
(
self
,
src
:
Tensor
,
offset
:
Optional
[
Tensor
],
count
:
Optional
[
Tensor
],
index
:
Tensor
)
->
None
:
self
.
_pull_index
=
(
self
.
_pull_index
+
1
)
%
self
.
pool_size
data
=
(
self
.
_pull_index
,
src
,
offset
,
count
,
index
)
self
.
_pull_queue
.
append
(
data
)
if
len
(
self
.
_pull_queue
)
<=
self
.
pool_size
:
self
.
_async_pull
(
self
.
_pull_index
,
src
,
offset
,
count
,
index
)
@
torch
.
no_grad
()
def
_async_pull
(
self
,
idx
:
int
,
src
:
Tensor
,
offset
:
Optional
[
Tensor
],
count
:
Optional
[
Tensor
],
index
:
Tensor
)
->
None
:
with
torch
.
cuda
.
stream
(
self
.
_pull_stream
(
idx
)):
read_async
(
src
,
offset
,
count
,
index
,
self
.
_cuda_buffer
(
idx
),
self
.
_cpu_buffer
(
idx
))
@
torch
.
no_grad
()
def
synchronize_pull
(
self
)
->
Tensor
:
idx
=
self
.
_pull_queue
[
0
][
0
]
synchronize
()
torch
.
cuda
.
synchronize
(
self
.
_pull_stream
(
idx
))
return
self
.
_cuda_buffer
(
idx
)
@
torch
.
no_grad
()
def
free_pull
(
self
)
->
None
:
self
.
_pull_queue
.
pop
(
0
)
if
len
(
self
.
_pull_queue
)
>=
self
.
pool_size
:
data
=
self
.
_pull_queue
[
self
.
pool_size
-
1
]
idx
,
src
,
offset
,
count
,
index
=
data
self
.
_async_pull
(
idx
,
src
,
offset
,
count
,
index
)
if
len
(
self
.
_pull_queue
)
==
0
:
self
.
_pull_index
=
-
1
@
torch
.
no_grad
()
def
async_push
(
self
,
src
:
Tensor
,
offset
:
Tensor
,
count
:
Tensor
,
dst
:
Tensor
)
->
None
:
self
.
_push_index
=
(
self
.
_push_index
+
1
)
%
self
.
pool_size
self
.
synchronize_push
(
self
.
_push_index
)
self
.
_push_cache
[
self
.
_push_index
]
=
src
with
torch
.
cuda
.
stream
(
self
.
_push_stream
(
self
.
_push_index
)):
write_async
(
src
,
offset
,
count
,
dst
)
@
torch
.
no_grad
()
def
synchronize_push
(
self
,
idx
:
Optional
[
int
]
=
None
)
->
None
:
if
idx
is
None
:
for
idx
in
range
(
self
.
pool_size
):
self
.
synchronize_push
(
idx
)
self
.
_push_index
=
-
1
else
:
torch
.
cuda
.
synchronize
(
self
.
_push_stream
(
idx
))
self
.
_push_cache
[
idx
]
=
None
def
forward
(
self
,
*
args
,
**
kwargs
):
""""""
raise
NotImplementedError
def
__repr__
(
self
):
return
(
f
'
{
self
.
__class__
.
__name__
}
(pool_size=
{
self
.
pool_size
}
, '
f
'buffer_size=
{
self
.
buffer_size
}
, '
f
'embedding_dim=
{
self
.
embedding_dim
}
, '
f
'device=
{
self
.
_device
}
)'
)
torch_geometric_autoscale/utils.py
0 → 100644
View file @
2f25da6c
from
typing
import
Optional
import
torch
from
torch
import
Tensor
def
index2mask
(
idx
:
Tensor
,
size
:
int
)
->
Tensor
:
mask
=
torch
.
zeros
(
size
,
dtype
=
torch
.
bool
,
device
=
idx
.
device
)
mask
[
idx
]
=
True
return
mask
def
compute_acc
(
logits
:
Tensor
,
y
:
Tensor
,
mask
:
Optional
[
Tensor
]
=
None
):
if
mask
is
not
None
:
logits
,
y
=
logits
[
mask
],
y
[
mask
]
if
y
.
dim
()
==
1
:
return
int
(
logits
.
argmax
(
dim
=-
1
).
eq
(
y
).
sum
())
/
y
.
size
(
0
)
else
:
y_pred
=
logits
>
0
y_true
=
y
>
0.5
tp
=
int
((
y_true
&
y_pred
).
sum
())
fp
=
int
((
~
y_true
&
y_pred
).
sum
())
fn
=
int
((
y_true
&
~
y_pred
).
sum
())
precision
=
tp
/
(
tp
+
fp
)
recall
=
tp
/
(
tp
+
fn
)
return
2
*
(
precision
*
recall
)
/
(
precision
+
recall
)
def
gen_masks
(
y
:
Tensor
,
train_per_class
:
int
=
20
,
val_per_class
:
int
=
30
,
num_splits
:
int
=
20
):
num_classes
=
int
(
y
.
max
())
+
1
train_mask
=
torch
.
zeros
(
y
.
size
(
0
),
num_splits
,
dtype
=
torch
.
bool
)
val_mask
=
torch
.
zeros
(
y
.
size
(
0
),
num_splits
,
dtype
=
torch
.
bool
)
for
c
in
range
(
num_classes
):
idx
=
(
y
==
c
).
nonzero
(
as_tuple
=
False
).
view
(
-
1
)
perm
=
torch
.
stack
(
[
torch
.
randperm
(
idx
.
size
(
0
))
for
_
in
range
(
num_splits
)],
dim
=
1
)
idx
=
idx
[
perm
]
train_idx
=
idx
[:
train_per_class
]
train_mask
.
scatter_
(
0
,
train_idx
,
True
)
val_idx
=
idx
[
train_per_class
:
train_per_class
+
val_per_class
]
val_mask
.
scatter_
(
0
,
val_idx
,
True
)
test_mask
=
~
(
train_mask
|
val_mask
)
return
train_mask
,
val_mask
,
test_mask
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment