Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
9309d91b
Commit
9309d91b
authored
Jun 07, 2021
by
rusty1s
Browse files
train_gcn2
parent
5a8b5eac
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
2 deletions
+98
-2
examples/README.md
examples/README.md
+8
-2
examples/train_gcn2.py
examples/train_gcn2.py
+90
-0
No files found.
examples/README.md
View file @
9309d91b
# Examples
Train
**GCN**
on
**Cora**
:
Train
[
**GCN**
](
https://arxiv.org/abs/1609.02907
)
on
**Cora**
:
```
python train_gcn.py --root=/tmp/datasets --device=0
```
Train
**GIN**
on
**Cluster**
:
Train
[
**GCN2**
](
https://arxiv.org/abs/1902.07153
)
on
**Cora**
:
```
python train_gcn2.py --root=/tmp/datasets --device=0
```
Train
[
**GIN**
](
https://arxiv.org/abs/1810.00826
)
on
**Cluster**
:
```
python train_gin.py --root=/tmp/datasets --device=0
...
...
examples/train_gcn2.py
0 → 100644
View file @
9309d91b
import
argparse
import
torch
from
torch_geometric.nn.conv.gcn_conv
import
gcn_norm
from
torch_geometric_autoscale.models
import
GCN2
from
torch_geometric_autoscale
import
metis
,
permute
,
SubgraphLoader
from
torch_geometric_autoscale
import
get_data
,
compute_acc
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--root'
,
type
=
str
,
required
=
True
,
help
=
'Root directory of dataset storage.'
)
parser
.
add_argument
(
'--device'
,
type
=
int
,
default
=
0
)
args
=
parser
.
parse_args
()
torch
.
manual_seed
(
12345
)
device
=
f
'cuda:
{
args
.
device
}
'
if
torch
.
cuda
.
is_available
()
else
'cpu'
data
,
in_channels
,
out_channels
=
get_data
(
args
.
root
,
name
=
'cora'
)
# Pre-process adjacency matrix for GCN:
data
.
adj_t
=
gcn_norm
(
data
.
adj_t
,
add_self_loops
=
True
)
# Pre-partition the graph using Metis:
perm
,
ptr
=
metis
(
data
.
adj_t
,
num_parts
=
40
,
log
=
True
)
data
=
permute
(
data
,
perm
,
log
=
True
)
loader
=
SubgraphLoader
(
data
,
ptr
,
batch_size
=
20
,
shuffle
=
True
)
# Make use of the pre-defined GCN+GAS model:
model
=
GCN2
(
num_nodes
=
data
.
num_nodes
,
in_channels
=
in_channels
,
hidden_channels
=
64
,
out_channels
=
out_channels
,
num_layers
=
64
,
alpha
=
0.1
,
theta
=
0.5
,
shared_weights
=
True
,
dropout
=
0.6
,
drop_input
=
True
,
pool_size
=
2
,
# Number of pinned CPU buffers
buffer_size
=
500
,
# Size of pinned CPU buffers (max #out-of-batch nodes)
).
to
(
device
)
optimizer
=
torch
.
optim
.
Adam
([
dict
(
params
=
model
.
reg_modules
.
parameters
(),
weight_decay
=
0.01
),
dict
(
params
=
model
.
nonreg_modules
.
parameters
(),
weight_decay
=
5e-4
)
],
lr
=
0.01
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
def
train
(
model
,
loader
,
optimizer
):
model
.
train
()
for
batch
,
*
args
in
loader
:
batch
=
batch
.
to
(
model
.
device
)
optimizer
.
zero_grad
()
out
=
model
(
batch
.
x
,
batch
.
adj_t
,
*
args
)
train_mask
=
batch
.
train_mask
[:
out
.
size
(
0
)]
loss
=
criterion
(
out
[
train_mask
],
batch
.
y
[:
out
.
size
(
0
)][
train_mask
])
loss
.
backward
()
torch
.
nn
.
utils
.
clip_grad_norm_
(
model
.
parameters
(),
1.0
)
optimizer
.
step
()
@
torch
.
no_grad
()
def
test
(
model
,
data
):
model
.
eval
()
# Full-batch inference since the graph is small
out
=
model
(
data
.
x
.
to
(
model
.
device
),
data
.
adj_t
.
to
(
model
.
device
)).
cpu
()
train_acc
=
compute_acc
(
out
,
data
.
y
,
data
.
train_mask
)
val_acc
=
compute_acc
(
out
,
data
.
y
,
data
.
val_mask
)
test_acc
=
compute_acc
(
out
,
data
.
y
,
data
.
test_mask
)
return
train_acc
,
val_acc
,
test_acc
test
(
model
,
data
)
# Fill the history.
best_val_acc
=
test_acc
=
0
for
epoch
in
range
(
1
,
501
):
train
(
model
,
loader
,
optimizer
)
train_acc
,
val_acc
,
tmp_test_acc
=
test
(
model
,
data
)
if
val_acc
>
best_val_acc
:
best_val_acc
=
val_acc
test_acc
=
tmp_test_acc
print
(
f
'Epoch:
{
epoch
:
03
d
}
, Train:
{
train_acc
:.
4
f
}
, Val:
{
val_acc
:.
4
f
}
, '
f
'Test:
{
tmp_test_acc
:.
4
f
}
, Final:
{
test_acc
:.
4
f
}
'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment