OpenDAS / pyg_autoscale · Commits

Commit 91efc915, authored Feb 04, 2021 by rusty1s
Parent: e2d2af18

    add gcn example

Showing 8 changed files with 158 additions and 35 deletions (+158, -35):
examples/train_gcn.py                          +85   -0
examples/train_gin.py                           +0   -0
torch_geometric_autoscale/loader.py             +1   -1
torch_geometric_autoscale/metis.py              +3   -2
torch_geometric_autoscale/models/__init__.py   +12  -12
torch_geometric_autoscale/models/base.py        +1   -1
torch_geometric_autoscale/models/gcn.py        +53  -16
torch_geometric_autoscale/pool.py               +3   -3
examples/train_gcn.py (new file, mode 100644)
import argparse

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric_autoscale.models import GCN
from torch_geometric_autoscale import (get_data, metis, permute,
                                       SubgraphLoader, compute_acc)

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
                    help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(12345)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

data, in_channels, out_channels = get_data(args.root, name='cora')

# Pre-process adjacency matrix for GCN:
data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)

# Pre-partition the graph using Metis:
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)

loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=out_channels,
    num_layers=2,
    dropout=0.5,
    drop_input=True,
    pool_size=2,
    buffer_size=1000,
).to(device)

optimizer = torch.optim.Adam([
    dict(params=model.reg_modules.parameters(), weight_decay=5e-4),
    dict(params=model.nonreg_modules.parameters(), weight_decay=0),
], lr=0.01)
criterion = torch.nn.CrossEntropyLoss()


def train(data, model, loader, optimizer):
    model.train()
    for batch, batch_size, n_id, offset, count in loader:
        batch = batch.to(device)
        train_mask = batch.train_mask[:batch_size]
        optimizer.zero_grad()
        out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
        loss = criterion(out[train_mask], batch.y[:batch_size][train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()


@torch.no_grad()
def test(data, model):
    model.eval()
    out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
    train_acc = compute_acc(out, data.y, data.train_mask)
    val_acc = compute_acc(out, data.y, data.val_mask)
    test_acc = compute_acc(out, data.y, data.test_mask)
    return train_acc, val_acc, test_acc


test(data, model)  # Fill history.

best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train(data, model, loader, optimizer)
    train_acc, val_acc, tmp_test_acc = test(data, model)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')
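For reference, each loader iteration yields the five values unpacked in train() above. A minimal sketch, appended to the script above under the same Cora setup, that inspects one mini-batch (the role of offset and count is inferred from their use in push_and_pull inside the model):

for batch, batch_size, n_id, offset, count in loader:
    # The first `batch_size` rows of the subgraph are the target nodes of
    # this mini-batch; the remaining rows are their out-of-batch neighbors
    # pulled in for message passing.
    print(batch.x.size(), batch.adj_t.sizes())
    # n_id maps local rows back to global node ids; offset/count locate
    # the in-batch nodes inside the history embedding buffers.
    print(batch_size, n_id.numel())
    break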
examples/train_gin.py (new file, mode 100644; added empty in this commit, +0 -0)
torch_geometric_autoscale/loader.py

@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data
 
-relabel_fn = torch.ops.scaling_gnns.relabel_one_hop
+relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
 
 
 class SubData(NamedTuple):
...
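Both here and in pool.py below, the custom-op namespace moves from the old working title scaling_gnns to torch_geometric_autoscale. As a reminder of how such ops resolve, a sketch in which the shared-library path is hypothetical (the real extension is built by the package's own setup):

import torch

# Hypothetical path; the compiled extension registers relabel_one_hop
# under the renamed namespace when loaded.
torch.ops.load_library('build/libtorch_geometric_autoscale.so')

# Ops resolve lazily as attributes of the namespace object; an unknown
# name only fails at first lookup, not at import time.
relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop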
torch_geometric_autoscale/metis.py

@@ -7,6 +7,8 @@ from torch import Tensor
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data
 
+partition_fn = torch.ops.torch_sparse.partition
+
 
 def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
           log: bool = True) -> Tuple[Tensor, Tensor]:
@@ -22,8 +24,7 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
         perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
     else:
         rowptr, col, _ = adj_t.csr()
-        cluster = torch.ops.torch_sparse.partition(rowptr, col, None,
-                                                   num_parts, recursive)
+        cluster = partition_fn(rowptr, col, None, num_parts, recursive)
         cluster, perm = cluster.sort()
         ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
...
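metis() returns perm, a node permutation that groups nodes by cluster, and ptr, CSR-style boundaries delimiting each cluster in the permuted ordering. A small pure-PyTorch sketch of the sort/ind2ptr step, with bincount + cumsum standing in for the ind2ptr op and made-up numbers:

import torch

# Toy assignment of 6 nodes to 3 partitions.
cluster = torch.tensor([2, 0, 1, 0, 2, 1])

# Sorting groups nodes by cluster; `perm` is exactly what
# permute(data, perm) would apply to reorder the graph.
cluster, perm = cluster.sort()

# ind2ptr turns sorted cluster ids into boundaries; bincount + cumsum
# is an equivalent pure-PyTorch formulation.
ptr = torch.cat([torch.zeros(1, dtype=torch.long),
                 torch.bincount(cluster, minlength=3).cumsum(0)])

print(perm)  # node ids grouped by cluster, e.g. tensor([1, 3, 2, 5, 0, 4])
print(ptr)   # tensor([0, 2, 4, 6]): cluster i owns permuted nodes ptr[i]:ptr[i+1]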
torch_geometric_autoscale/models/__init__.py

 from .base import ScalableGNN
 from .gcn import GCN
-from .gat import GAT
-from .appnp import APPNP
-from .gcn2 import GCN2
-from .gin import GIN
-from .pna import PNA
-from .pna_jk import PNA_JK
+# from .gat import GAT
+# from .appnp import APPNP
+# from .gcn2 import GCN2
+# from .gin import GIN
+# from .pna import PNA
+# from .pna_jk import PNA_JK
 
 __all__ = [
     'ScalableGNN',
     'GCN',
-    'GAT',
-    'APPNP',
-    'GCN2',
-    'GIN',
-    'PNA',
-    'PNA_JK',
+    # 'GAT',
+    # 'APPNP',
+    # 'GCN2',
+    # 'GIN',
+    # 'PNA',
+    # 'PNA_JK',
 ]
torch_geometric_autoscale/models/base.py

@@ -164,5 +164,5 @@ class ScalableGNN(torch.nn.Module):
     @torch.no_grad()
     def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
+                      state: Dict[str, Any]) -> Tensor:
         raise NotImplementedError
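This one-line fix is more than cosmetic: typing.Dict requires exactly two parameters, so the old annotation raises a TypeError as soon as it is evaluated. A quick standalone check:

from typing import Any, Dict

try:
    bad = Dict[Any]        # single parameter: rejected when evaluated
except TypeError as err:
    print(err)             # e.g. "Too few parameters for typing.Dict"

ok = Dict[str, Any]        # key type and value type, as in the fix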
torch_geometric_autoscale/models/gcn.py

-from typing import Optional, Dict, Any
+from typing import Optional
 
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torch.nn import ModuleList, BatchNorm1d
+from torch.nn import ModuleList, Linear, BatchNorm1d
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GCNConv
 
-from scaling_gnns.models.base2 import ScalableGNN
+from torch_geometric_autoscale.models import ScalableGNN
 
 
 class GCN(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  out_channels: int, num_layers: int, dropout: float = 0.0,
                  drop_input: bool = True, batch_norm: bool = False,
-                 residual: bool = False, pool_size: Optional[int] = None,
+                 residual: bool = False, linear: bool = False,
+                 pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
         super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
                                   pool_size, buffer_size, device)
@@ -25,29 +26,43 @@ class GCN(ScalableGNN):
         self.drop_input = drop_input
         self.batch_norm = batch_norm
         self.residual = residual
+        self.linear = linear
+
+        if linear:
+            self.lins = ModuleList()
+            self.lins.append(Linear(in_channels, hidden_channels))
+            self.lins.append(Linear(hidden_channels, out_channels))
 
         self.convs = ModuleList()
         for i in range(num_layers):
-            in_dim = in_channels if i == 0 else hidden_channels
-            out_dim = out_channels if i == num_layers - 1 else hidden_channels
+            in_dim = out_dim = hidden_channels
+            if i == 0 and not linear:
+                in_dim = in_channels
+            if i == num_layers - 1 and not linear:
+                out_dim = out_channels
             conv = GCNConv(in_dim, out_dim, normalize=False)
             self.convs.append(conv)
 
         self.bns = ModuleList()
-        for i in range(num_layers - 1):
+        for i in range(num_layers):
             bn = BatchNorm1d(hidden_channels)
             self.bns.append(bn)
 
     @property
     def reg_modules(self):
-        return ModuleList(list(self.convs[:-1]) + list(self.bns))
+        if self.linear:
+            return ModuleList(list(self.convs) + list(self.bns))
+        else:
+            return ModuleList(list(self.convs[:-1]) + list(self.bns))
 
     @property
     def nonreg_modules(self):
-        return self.convs[-1:]
+        return self.lins if self.linear else self.convs[-1:]
 
     def reset_parameters(self):
         super(GCN, self).reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
@@ -61,6 +76,10 @@ class GCN(ScalableGNN):
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
 
+        if self.linear:
+            x = self.lins[0](x).relu_()
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
         for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
             h = conv(x, adj_t)
             if self.batch_norm:
@@ -71,23 +90,41 @@ class GCN(ScalableGNN):
             x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
             x = F.dropout(x, p=self.dropout, training=self.training)
 
-        x = self.convs[-1](x, adj_t)
-        return x
+        h = self.convs[-1](x, adj_t)
+
+        if not self.linear:
+            return h
+
+        if self.batch_norm:
+            h = self.bns[-1](h)
+        if self.residual and h.size(-1) == x.size(-1):
+            h += x[:h.size(0)]
+        h = h.relu_()
+        h = F.dropout(h, p=self.dropout, training=self.training)
+        return self.lins[1](h)
 
     @torch.no_grad()
-    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
-        if layer == 0 and self.drop_input:
-            x = F.dropout(x, p=self.dropout, training=self.training)
+    def forward_layer(self, layer, x, adj_t, state):
+        if layer == 0:
+            if self.drop_input:
+                x = F.dropout(x, p=self.dropout, training=self.training)
+            if self.linear:
+                x = self.lins[0](x).relu_()
+                x = F.dropout(x, p=self.dropout, training=self.training)
         else:
             x = F.dropout(x, p=self.dropout, training=self.training)
 
         h = self.convs[layer](x, adj_t)
 
-        if layer < self.num_layers - 1:
+        if layer < self.num_layers - 1 or self.linear:
             if self.batch_norm:
                 h = self.bns[layer](h)
             if self.residual and h.size(-1) == x.size(-1):
                 h += x[:h.size(0)]
             h = h.relu_()
 
+        if self.linear:
+            h = F.dropout(h, p=self.dropout, training=self.training)
+            h = self.lins[1](h)
+
         return h
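The main functional change here is the new linear flag: with linear=True, two Linear layers take over the input and output projections, every GCNConv maps hidden_channels to hidden_channels, and the last convolution also receives batch norm, the optional residual, and a ReLU. A construction sketch reusing the variables from examples/train_gcn.py above:

# Sketch only: same Cora setup as examples/train_gcn.py, but with the new
# linear=True mode, so lins[0] projects in_channels -> hidden_channels and
# lins[1] projects hidden_channels -> out_channels around the convolutions.
model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=out_channels,
    num_layers=2,
    dropout=0.5,
    linear=True,      # new in this commit
    pool_size=2,
    buffer_size=1000,
).to(device)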
torch_geometric_autoscale/pool.py

@@ -4,9 +4,9 @@ import torch
 from torch import Tensor
 from torch.cuda import Stream
 
-synchronize = torch.ops.scaling_gnns.synchronize
-read_async = torch.ops.scaling_gnns.read_async
-write_async = torch.ops.scaling_gnns.write_async
+synchronize = torch.ops.torch_geometric_autoscale.synchronize
+read_async = torch.ops.torch_geometric_autoscale.read_async
+write_async = torch.ops.torch_geometric_autoscale.write_async
 
 
 class AsyncIOPool(torch.nn.Module):
...
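AsyncIOPool builds on CUDA streams to overlap history reads and writes with compute. A generic sketch of the underlying PyTorch pattern only (not the AsyncIOPool implementation itself): pinned host memory plus a side stream makes the host-to-device copy asynchronous:

import torch
from torch.cuda import Stream

if torch.cuda.is_available():
    stream = Stream()
    # Pinned (page-locked) host memory is required for truly async copies.
    host = torch.randn(1000, 256, pin_memory=True)
    with torch.cuda.stream(stream):
        dev = host.to('cuda', non_blocking=True)
    # Make the default stream wait before consuming the copied tensor.
    torch.cuda.current_stream().wait_stream(stream)
    print(dev.sum())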