OpenDAS / pyg_autoscale / Commits

Commit 91efc915, parent e2d2af18
Authored Feb 04, 2021 by rusty1s

    add gcn example

Showing 8 changed files with 158 additions and 35 deletions (+158, -35):
examples/train_gcn.py                         +85   -0
examples/train_gin.py                          +0   -0
torch_geometric_autoscale/loader.py            +1   -1
torch_geometric_autoscale/metis.py             +3   -2
torch_geometric_autoscale/models/__init__.py  +12  -12
torch_geometric_autoscale/models/base.py       +1   -1
torch_geometric_autoscale/models/gcn.py       +53  -16
torch_geometric_autoscale/pool.py              +3   -3
examples/train_gcn.py (new file, mode 100644):

```python
import argparse

import torch
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from torch_geometric_autoscale.models import GCN
from torch_geometric_autoscale import (get_data, metis, permute,
                                       SubgraphLoader, compute_acc)

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True,
                    help='Root directory of dataset storage.')
parser.add_argument('--device', type=int, default=0)
args = parser.parse_args()

torch.manual_seed(12345)
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'

data, in_channels, out_channels = get_data(args.root, name='cora')

# Pre-process adjacency matrix for GCN:
data.adj_t = gcn_norm(data.adj_t, add_self_loops=True)

# Pre-partition the graph using Metis:
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)

loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)

model = GCN(
    num_nodes=data.num_nodes,
    in_channels=in_channels,
    hidden_channels=16,
    out_channels=out_channels,
    num_layers=2,
    dropout=0.5,
    drop_input=True,
    pool_size=2,
    buffer_size=1000,
).to(device)

optimizer = torch.optim.Adam([
    dict(params=model.reg_modules.parameters(), weight_decay=5e-4),
    dict(params=model.nonreg_modules.parameters(), weight_decay=0),
], lr=0.01)
criterion = torch.nn.CrossEntropyLoss()


def train(data, model, loader, optimizer):
    model.train()
    for batch, batch_size, n_id, offset, count in loader:
        batch = batch.to(device)
        train_mask = batch.train_mask[:batch_size]
        optimizer.zero_grad()
        out = model(batch.x, batch.adj_t, batch_size, n_id, offset, count)
        loss = criterion(out[train_mask], batch.y[:batch_size][train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()


@torch.no_grad()
def test(data, model):
    model.eval()
    out = model(data.x.to(model.device), data.adj_t.to(model.device)).cpu()
    train_acc = compute_acc(out, data.y, data.train_mask)
    val_acc = compute_acc(out, data.y, data.val_mask)
    test_acc = compute_acc(out, data.y, data.test_mask)
    return train_acc, val_acc, test_acc


test(data, model)  # Fill history.

best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train(data, model, loader, optimizer)
    train_acc, val_acc, tmp_test_acc = test(data, model)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
          f'Test: {tmp_test_acc:.4f}, Final: {test_acc:.4f}')
```
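The `test(data, model)  # Fill history.` call before the training loop primes the per-layer history embeddings that `push_and_pull` (used inside the model; see the `gcn.py` diff below) reads from during mini-batch training. As a conceptual sketch only, with assumed semantics rather than the library's actual C++/CUDA-backed implementation, the push/pull pattern looks roughly like this:

```python
import torch

def push_and_pull_sketch(history: torch.Tensor, x: torch.Tensor,
                         batch_size: int, n_id: torch.Tensor) -> torch.Tensor:
    # Sketch (assumed semantics): in-batch nodes push fresh embeddings into
    # the history table; out-of-batch neighbors pull their possibly stale
    # embeddings from it. `n_id` lists in-batch nodes first.
    history[n_id[:batch_size]] = x[:batch_size].detach()   # push
    out = x.clone()
    out[batch_size:] = history[n_id[batch_size:]]          # pull
    return out
```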
examples/train_gin.py (new empty file, mode 100644; +0, -0)
torch_geometric_autoscale/loader.py:

```diff
@@ -8,7 +8,7 @@ from torch.utils.data import DataLoader
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data
 
-relabel_fn = torch.ops.scaling_gnns.relabel_one_hop
+relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
 
 
 class SubData(NamedTuple):
```
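The renamed `relabel_one_hop` op is a custom kernel whose exact signature is not shown in this diff. As a rough pure-Python sketch of the general idea (assumed, for illustration): map the global ids of a mini-batch's one-hop neighborhood onto consecutive local ids, batch nodes first.

```python
import torch

def relabel_one_hop_sketch(rowptr, col, n_id):
    # Sketch (assumed semantics) of one-hop relabeling on a CSR graph:
    # batch nodes keep local ids 0..len(n_id)-1; each newly seen neighbor
    # is appended with the next free local id.
    local = {int(n): i for i, n in enumerate(n_id)}
    new_rowptr, new_col = [0], []
    for n in n_id.tolist():
        for c in col[rowptr[n]:rowptr[n + 1]].tolist():
            if c not in local:
                local[c] = len(local)
            new_col.append(local[c])
        new_rowptr.append(len(new_col))
    full_n_id = torch.tensor(list(local.keys()))  # global ids in local order
    return torch.tensor(new_rowptr), torch.tensor(new_col), full_n_id
```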
torch_geometric_autoscale/metis.py:

```diff
@@ -7,6 +7,8 @@ from torch import Tensor
 from torch_sparse import SparseTensor
 from torch_geometric.data import Data
 
+partition_fn = torch.ops.torch_sparse.partition
+
 
 def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
           log: bool = True) -> Tuple[Tensor, Tensor]:
@@ -22,8 +24,7 @@ def metis(adj_t: SparseTensor, num_parts: int, recursive: bool = False,
         perm, ptr = torch.arange(num_nodes), torch.tensor([0, num_nodes])
     else:
         rowptr, col, _ = adj_t.csr()
-        cluster = torch.ops.torch_sparse.partition(rowptr, col, None,
-                                                   num_parts, recursive)
+        cluster = partition_fn(rowptr, col, None, num_parts, recursive)
         cluster, perm = cluster.sort()
         ptr = torch.ops.torch_sparse.ind2ptr(cluster, num_parts)
```
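For readers unfamiliar with `torch.ops.torch_sparse.ind2ptr`: after sorting the METIS cluster assignments, it converts them into a CSR-style pointer vector. A minimal pure-Python equivalent (a sketch, not the actual kernel):

```python
import torch

def ind2ptr_sketch(ind: torch.Tensor, num_parts: int) -> torch.Tensor:
    # Sorted cluster vector -> pointer vector: the nodes of partition p
    # occupy positions ptr[p]:ptr[p + 1] of the permuted node ordering.
    ptr = torch.zeros(num_parts + 1, dtype=torch.long)
    ptr[1:] = torch.cumsum(torch.bincount(ind, minlength=num_parts), dim=0)
    return ptr

# ind2ptr_sketch(torch.tensor([0, 0, 1, 2, 2, 2]), 3) -> tensor([0, 2, 3, 6])
```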
torch_geometric_autoscale/models/__init__.py:

```diff
 from .base import ScalableGNN
 from .gcn import GCN
-from .gat import GAT
-from .appnp import APPNP
-from .gcn2 import GCN2
-from .gin import GIN
-from .pna import PNA
-from .pna_jk import PNA_JK
+# from .gat import GAT
+# from .appnp import APPNP
+# from .gcn2 import GCN2
+# from .gin import GIN
+# from .pna import PNA
+# from .pna_jk import PNA_JK
 
 __all__ = [
     'ScalableGNN',
     'GCN',
-    'GAT',
-    'APPNP',
-    'GCN2',
-    'GIN',
-    'PNA',
-    'PNA_JK',
+    # 'GAT',
+    # 'APPNP',
+    # 'GCN2',
+    # 'GIN',
+    # 'PNA',
+    # 'PNA_JK',
 ]
```
torch_geometric_autoscale/models/base.py:

```diff
@@ -164,5 +164,5 @@ class ScalableGNN(torch.nn.Module):
     @torch.no_grad()
     def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
+                      state: Dict[str, Any]) -> Tensor:
         raise NotImplementedError
```
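This one-line change fixes an invalid annotation: `typing.Dict` takes a key type and a value type, so `Dict[Any]` raises a `TypeError` ("Too few parameters") the moment the function is defined, while `Dict[str, Any]` is the valid form.

```python
from typing import Any, Dict

state: Dict[str, Any] = {}  # valid: string keys, arbitrary values
# state: Dict[Any] = {}     # TypeError at definition time: too few parameters
```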
torch_geometric_autoscale/models/gcn.py:

```diff
-from typing import Optional, Dict, Any
+from typing import Optional
 
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from torch.nn import ModuleList, BatchNorm1d
+from torch.nn import ModuleList, Linear, BatchNorm1d
 from torch_sparse import SparseTensor
 from torch_geometric.nn import GCNConv
 
-from scaling_gnns.models.base2 import ScalableGNN
+from torch_geometric_autoscale.models import ScalableGNN
 
 
 class GCN(ScalableGNN):
     def __init__(self, num_nodes: int, in_channels, hidden_channels: int,
                  out_channels: int, num_layers: int, dropout: float = 0.0,
                  drop_input: bool = True, batch_norm: bool = False,
-                 residual: bool = False, pool_size: Optional[int] = None,
+                 residual: bool = False, linear: bool = False,
+                 pool_size: Optional[int] = None,
                  buffer_size: Optional[int] = None, device=None):
         super(GCN, self).__init__(num_nodes, hidden_channels, num_layers,
                                   pool_size, buffer_size, device)
@@ -25,29 +26,43 @@ class GCN(ScalableGNN):
         self.drop_input = drop_input
         self.batch_norm = batch_norm
         self.residual = residual
+        self.linear = linear
 
+        if linear:
+            self.lins = ModuleList()
+            self.lins.append(Linear(in_channels, hidden_channels))
+            self.lins.append(Linear(hidden_channels, out_channels))
+
         self.convs = ModuleList()
         for i in range(num_layers):
-            in_dim = in_channels if i == 0 else hidden_channels
-            out_dim = out_channels if i == num_layers - 1 else hidden_channels
+            in_dim = out_dim = hidden_channels
+            if i == 0 and not linear:
+                in_dim = in_channels
+            if i == num_layers - 1 and not linear:
+                out_dim = out_channels
             conv = GCNConv(in_dim, out_dim, normalize=False)
             self.convs.append(conv)
 
         self.bns = ModuleList()
-        for i in range(num_layers - 1):
+        for i in range(num_layers):
             bn = BatchNorm1d(hidden_channels)
             self.bns.append(bn)
 
     @property
     def reg_modules(self):
-        return ModuleList(list(self.convs[:-1]) + list(self.bns))
+        if self.linear:
+            return ModuleList(list(self.convs) + list(self.bns))
+        else:
+            return ModuleList(list(self.convs[:-1]) + list(self.bns))
 
     @property
     def nonreg_modules(self):
-        return self.convs[-1:]
+        return self.lins if self.linear else self.convs[-1:]
 
     def reset_parameters(self):
         super(GCN, self).reset_parameters()
+        for lin in self.lins:
+            lin.reset_parameters()
         for conv in self.convs:
             conv.reset_parameters()
         for bn in self.bns:
@@ -61,6 +76,10 @@ class GCN(ScalableGNN):
         if self.drop_input:
             x = F.dropout(x, p=self.dropout, training=self.training)
 
+        if self.linear:
+            x = self.lins[0](x).relu_()
+            x = F.dropout(x, p=self.dropout, training=self.training)
+
         for conv, bn, hist in zip(self.convs[:-1], self.bns, self.histories):
             h = conv(x, adj_t)
             if self.batch_norm:
@@ -71,23 +90,41 @@ class GCN(ScalableGNN):
             x = self.push_and_pull(hist, x, batch_size, n_id, offset, count)
             x = F.dropout(x, p=self.dropout, training=self.training)
 
-        x = self.convs[-1](x, adj_t)
-
-        return x
+        h = self.convs[-1](x, adj_t)
+
+        if not self.linear:
+            return h
+
+        if self.batch_norm:
+            h = self.bns[-1](h)
+        if self.residual and h.size(-1) == x.size(-1):
+            h += x[:h.size(0)]
+        h = h.relu_()
+        h = F.dropout(h, p=self.dropout, training=self.training)
+        return self.lins[1](h)
 
     @torch.no_grad()
-    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor,
-                      state: Dict[Any]) -> Tensor:
-        if layer == 0 and self.drop_input:
-            x = F.dropout(x, p=self.dropout, training=self.training)
+    def forward_layer(self, layer, x, adj_t, state):
+        if layer == 0:
+            if self.drop_input:
+                x = F.dropout(x, p=self.dropout, training=self.training)
+            if self.linear:
+                x = self.lins[0](x).relu_()
+                x = F.dropout(x, p=self.dropout, training=self.training)
+        else:
+            x = F.dropout(x, p=self.dropout, training=self.training)
 
         h = self.convs[layer](x, adj_t)
 
-        if layer < self.num_layers - 1:
+        if layer < self.num_layers - 1 or self.linear:
             if self.batch_norm:
                 h = self.bns[layer](h)
             if self.residual and h.size(-1) == x.size(-1):
                 h += x[:h.size(0)]
             h = h.relu_()
 
+        if self.linear:
+            h = F.dropout(h, p=self.dropout, training=self.training)
+            h = self.lins[1](h)
+
         return h
```
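The net effect of the new `linear` flag: when set, the convolution stack is sandwiched between an input `Linear(in_channels, hidden_channels)` and an output `Linear(hidden_channels, out_channels)`, every `GCNConv` stays at `hidden_channels` width, and the extra `BatchNorm1d` covers the last convolution. A hypothetical instantiation (argument values are illustrative, not from this commit):

```python
# Illustrative only: enable the new input/output Linear wrappers.
model = GCN(num_nodes=data.num_nodes, in_channels=in_channels,
            hidden_channels=16, out_channels=out_channels, num_layers=2,
            dropout=0.5, linear=True).to(device)
```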
torch_geometric_autoscale/pool.py:

```diff
@@ -4,9 +4,9 @@ import torch
 from torch import Tensor
 from torch.cuda import Stream
 
-synchronize = torch.ops.scaling_gnns.synchronize
-read_async = torch.ops.scaling_gnns.read_async
-write_async = torch.ops.scaling_gnns.write_async
+synchronize = torch.ops.torch_geometric_autoscale.synchronize
+read_async = torch.ops.torch_geometric_autoscale.read_async
+write_async = torch.ops.torch_geometric_autoscale.write_async
 
 
 class AsyncIOPool(torch.nn.Module):
```
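The renamed ops back `AsyncIOPool`'s overlapping of history transfers with compute. As a generic illustration of the underlying PyTorch pattern (not the custom kernels this file binds):

```python
import torch

# Generic side-stream transfer sketch: copy a (hypothetical) history slice
# to the GPU on a separate CUDA stream so the copy overlaps with compute
# running on the default stream.
cpu_chunk = torch.randn(1024, 16)

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        # Pinned memory enables a truly asynchronous host-to-device copy.
        gpu_chunk = cpu_chunk.pin_memory().to('cuda', non_blocking=True)
    # Make the default stream wait before compute consumes gpu_chunk.
    torch.cuda.current_stream().wait_stream(side_stream)
```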