OpenDAS / pyg_autoscale · Commits
"docs/source/vscode:/vscode.git/clone" did not exist on "039a711da1c7833db3b45d212a4aaab9cfbf7a09"
Commit ec590171, authored Jun 08, 2021 by rusty1s
parent a73bb262

update training GIN script

Showing 1 changed file with 45 additions and 46 deletions.
examples/train_gin.py  (+45, -46)
 import argparse

 import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch import Tensor
+from torch.optim.lr_scheduler import ReduceLROnPlateau as ReduceLR
 from torch.nn import Identity, Sequential, Linear, ReLU, BatchNorm1d
+from torch_sparse import SparseTensor
 import torch_geometric.transforms as T
 from torch_geometric.nn import GINConv
 from torch_geometric.data import DataLoader
 from torch_geometric.datasets import GNNBenchmarkDataset as SBM

-from torch_geometric_autoscale import get_data
+from torch_geometric_autoscale import metis, permute
 from torch_geometric_autoscale.models import ScalableGNN
-from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader
+from torch_geometric_autoscale import (get_data, SubgraphLoader,
+                                       EvalSubgraphLoader)

 parser = argparse.ArgumentParser()
 parser.add_argument('--root', type=str, required=True,
...
@@ -23,32 +26,33 @@ device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
 data, in_channels, out_channels = get_data(args.root, name='CLUSTER')
-train_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='train',
-                    pre_transform=T.ToSparseTensor())
-val_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='val',
-                  pre_transform=T.ToSparseTensor())
-test_dataset = SBM(f'{args.root}/SBM', name='CLUSTER', split='test',
-                   pre_transform=T.ToSparseTensor())
-val_loader = DataLoader(val_dataset, batch_size=512)
-test_loader = DataLoader(test_dataset, batch_size=512)
-
-ptr = [0]
-for d in train_dataset:
-    # Minimize inter-connectivity between batches:
-    ptr += [ptr[-1] + d.num_nodes // 2, ptr[-1] + d.num_nodes]
-ptr = torch.tensor(ptr)
+# Pre-partition the graph using Metis:
+perm, ptr = metis(data.adj_t, num_parts=10000, log=True)
+data = permute(data, perm, log=True)

 train_loader = SubgraphLoader(data, ptr, batch_size=256, shuffle=True,
                               num_workers=6, persistent_workers=True)
 eval_loader = EvalSubgraphLoader(data, ptr, batch_size=256)

+# We use the regular PyTorch Geometric dataset for evaluation:
+kwargs = {'name': 'CLUSTER', 'pre_transform': T.ToSparseTensor()}
+val_dataset = SBM(f'{args.root}/SBM', split='val', **kwargs)
+test_dataset = SBM(f'{args.root}/SBM', split='test', **kwargs)
+val_loader = DataLoader(val_dataset, batch_size=512)
+test_loader = DataLoader(test_dataset, batch_size=512)


-class GIN(ScalableGNN):
-    def __init__(self, num_nodes, in_channels, hidden_channels, out_channels,
-                 num_layers):
-        super(GIN, self).__init__(num_nodes, hidden_channels, num_layers,
-                                  pool_size=2, buffer_size=200000)
+# We define our own GAS+GIN module:
+class GIN(ScalableGNN):
+    def __init__(self, num_nodes: int, in_channels: int, hidden_channels: int,
+                 out_channels: int, num_layers: int):
+        super().__init__(num_nodes, hidden_channels, num_layers, pool_size=2,
+                         buffer_size=60000)
+        # pool_size determines the number of pinned CPU buffers
+        # buffer_size determines the size of pinned CPU buffers,
+        # i.e. the maximum number of out-of-mini-batch nodes

         self.in_channels = in_channels
         self.out_channels = out_channels

         self.lins = torch.nn.ModuleList()
...
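For reference, a minimal sketch of how the partitioning pipeline in this hunk fits together, using only calls that appear in the diff (metis, permute, SubgraphLoader); the layout of the per-batch tuple is taken from the train() loop further down, and the inline explanations are best-effort assumptions rather than documented API guarantees:

    # Partition the merged graph, reorder nodes so that each part is contiguous,
    # and let the loader assemble mini-batches out of whole parts.
    perm, ptr = metis(data.adj_t, num_parts=10000, log=True)
    data = permute(data, perm, log=True)
    loader = SubgraphLoader(data, ptr, batch_size=256, shuffle=True)

    for batch, batch_size, n_id, offset, count in loader:
        # batch_size: number of in-batch nodes; n_id: global ids of all nodes in
        # the sampled subgraph; offset/count: where the in-batch nodes live in
        # the history buffers (consumed by push_and_pull in the model below).
        pass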
@@ -64,46 +68,43 @@ class GIN(ScalableGNN):
             mlp = Sequential(
                 Linear(hidden_channels, hidden_channels),
                 BatchNorm1d(hidden_channels, track_running_stats=False),
-                ReLU(inplace=True),
+                ReLU(),
                 Linear(hidden_channels, hidden_channels),
                 ReLU(),
             )
             self.mlps.append(mlp)

-    def forward(self, x, adj_t, batch_size=None, n_id=None, offset=None,
-                count=None):
-        reg = 0
+    def forward(self, x: Tensor, adj_t: SparseTensor, *args):
         x = self.lins[0](x).relu_()

-        for i, (conv, mlp, hist) in enumerate(
-                zip(self.convs[:-1], self.mlps[:-1], self.histories)):
+        reg = 0
+        it = zip(self.convs[:-1], self.mlps[:-1], self.histories)
+        for i, (conv, mlp, history) in enumerate(it):
             h = conv((x, x[:adj_t.size(0)]), adj_t)

-            # Enforce Lipschitz continuity via regularization (part 1):
+            # Regularize Lipschitz continuity via regularization (part 1):
             if i > 0 and self.training:
-                eps = 0.01 * torch.randn_like(h)
-                approx = mlp(h + eps)
+                approx = mlp(h + 0.1 * torch.randn_like(h))

             h = mlp(h)

-            # Enforce Lipschitz continuity via regularization (part 2):
+            # Regularize Lipschitz continuity via regularization (part 2):
             if i > 0 and self.training:
                 diff = (h - approx).norm(dim=-1)
                 reg += diff.mean() / len(self.histories)

-            h += x[:h.size(0)]
-            x = self.push_and_pull(hist, h, batch_size, n_id, offset, count)
+            h += x[:h.size(0)]  # Simple skip-connection
+            x = self.push_and_pull(history, h, *args)

         h = self.convs[-1]((x, x[:adj_t.size(0)]), adj_t)
         h = self.mlps[-1](h)
         h += x[:h.size(0)]
-        x = self.lins[1](h)
-        return x, reg
+        return self.lins[1](h), reg

     @torch.no_grad()
-    def forward_layer(self, layer, x, adj_t, state):
+    def forward_layer(self, layer: int, x: Tensor, adj_t: SparseTensor, state):
         if layer == 0:
             x = self.lins[0](x).relu_()
...
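The two "part 1 / part 2" blocks above implement a simple smoothness penalty: the MLP is applied both to h and to a slightly perturbed copy of h, and the mean distance between the two outputs is accumulated into the reg value that forward() returns alongside the predictions. A stand-alone sketch of the same computation (mlp, h, and the 0.1 noise scale follow the new lines in this hunk; nothing here introduces additional API):

    eps = 0.1 * torch.randn_like(h)                     # small random perturbation
    approx = mlp(h + eps)                               # output on the perturbed input
    reg_term = (mlp(h) - approx).norm(dim=-1).mean()    # penalize large output changes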
@@ -118,7 +119,7 @@ class GIN(ScalableGNN):
 model = GIN(
-    num_nodes=train_dataset.data.num_nodes,
+    num_nodes=data.num_nodes,
     in_channels=in_channels,
     hidden_channels=128,
     out_channels=out_channels,
...
@@ -127,23 +128,20 @@ model = GIN(
 optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
 criterion = torch.nn.CrossEntropyLoss()
-scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20,
-                              min_lr=1e-5)
+scheduler = ReduceLR(optimizer, 'max', factor=0.5, patience=20, min_lr=1e-5)


 def train(model, loader, optimizer):
     model.train()

     total_loss = total_examples = 0
-    for batch, batch_size, n_id, offset, count in loader:
+    for batch, *args in loader:
         batch = batch.to(model.device)
         optimizer.zero_grad()
-        out, reg = model(batch.x, batch.adj_t, batch_size, n_id, offset,
-                         count)
-        loss = criterion(out, batch.y[:batch_size]) + reg
+        out, reg = model(batch.x, batch.adj_t, *args)
+        loss = criterion(out, batch.y[:out.size(0)]) + reg
         loss.backward()
         optimizer.step()

         total_loss += float(loss) * int(out.size(0))
         total_examples += int(out.size(0))
...
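Note what the *args change buys here: train() no longer needs to know which extra values the loader yields; whatever SubgraphLoader produces after the batch itself is forwarded untouched to the model, which in turn hands it to push_and_pull(). A short sketch of the resulting call chain, with the tuple contents taken from the removed lines above:

    for batch, *args in train_loader:        # args == [batch_size, n_id, offset, count]
        out, reg = model(batch.x, batch.adj_t, *args)
        # inside GIN.forward(): x = self.push_and_pull(history, h, *args)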
@@ -172,6 +170,7 @@ def mini_test(model, loader, y):
+mini_test(model, eval_loader, data.y)  # Fill history.
 for epoch in range(1, 151):
     lr = optimizer.param_groups[0]['lr']
     loss = train(model, train_loader, optimizer)
...
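Since the scheduler is constructed with the 'max' mode, it expects a quantity to maximize (such as validation accuracy) once per epoch. A hypothetical continuation of the truncated loop above; the val_acc name and the assumption that mini_test() returns an accuracy are illustrative only, not taken from this commit:

    for epoch in range(1, 151):
        lr = optimizer.param_groups[0]['lr']
        loss = train(model, train_loader, optimizer)
        val_acc = mini_test(model, eval_loader, data.y)   # assumed to return accuracy
        scheduler.step(val_acc)                           # ReduceLROnPlateau takes the metric here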