OpenDAS / dgl / Commits

Commit 0f127637 (unverified)
Authored Aug 21, 2019 by VoVAllen; committed via GitHub on Aug 21, 2019
Parent: 77c58289

[Model] Early stop GAT (#750)

* Add early stop
* add mxnet version
* Poke ci
Changes: 4 changed files with 77 additions and 7 deletions (+77 -7)

    examples/mxnet/gat/train.py      +7   -5
    examples/mxnet/gat/utils.py      +29  -0
    examples/pytorch/gat/train.py    +12  -2
    examples/pytorch/gat/utils.py    +29  -0
examples/mxnet/gat/train.py

@@ -18,7 +18,7 @@ import numpy as np
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data
 from gat import GAT
+from utils import EarlyStopping

 def elu(data):
     return mx.nd.LeakyReLU(data, act_type='elu')

@@ -75,6 +75,7 @@ def main(args):
                 args.alpha,
                 args.residual)
+    stopper = EarlyStopping(patience=100)
     model.initialize(ctx=ctx)
     # use optimizer

@@ -95,10 +96,11 @@ def main(args):
         dur.append(time.time() - t0)
         print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
             epoch, loss.asnumpy()[0], np.mean(dur), n_edges / np.mean(dur) / 1000))
+        if epoch % 100 == 0:
             val_accuracy = evaluate(model, features, labels, val_mask)
             print("Validation Accuracy {:.4f}".format(val_accuracy))
+            if stopper.step(val_accuracy, model):
+                break
+    model.load_parameters('model.param')
     test_accuracy = evaluate(model, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(test_accuracy))
examples/mxnet/gat/utils.py (new file, 0 → 100644)

import numpy as np
import torch


class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def step(self, acc, model):
        score = acc
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        '''Saves model when validation loss decrease.'''
        model.save_parameters('model.param')
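
For orientation, here is a minimal sketch of how this helper is intended to be driven, mirroring the loop in examples/mxnet/gat/train.py above. The tiny Gluon model and the hard-coded accuracy sequence are illustrative assumptions, not part of the commit; only the EarlyStopping class, the 'model.param' checkpoint, and the step/restore pattern come from the diff (note that utils.py as written also imports torch, so torch must be installed even for the MXNet example).

# Illustrative sketch (not part of the commit): a stand-in training loop that
# exercises EarlyStopping the same way train.py does after this change.
import mxnet as mx
from mxnet.gluon import nn

from utils import EarlyStopping   # assumes this utils.py is on the import path

net = nn.Dense(2)                 # hypothetical stand-in for the GAT model
net.initialize()
net(mx.nd.ones((1, 4)))           # one forward pass so parameters exist to save

stopper = EarlyStopping(patience=3)
for epoch, val_acc in enumerate([0.70, 0.75, 0.74, 0.73, 0.72, 0.71]):
    # val_acc stands in for evaluate(model, features, labels, val_mask)
    if stopper.step(val_acc, net):    # saves 'model.param' whenever val_acc improves
        print(f"early stop at epoch {epoch}")
        break

net.load_parameters('model.param')    # restore the best checkpoint, as train.py does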
examples/pytorch/gat/train.py

@@ -18,12 +18,15 @@ import torch.nn.functional as F
 from dgl import DGLGraph
 from dgl.data import register_data_args, load_data
 from gat import GAT
+from utils import EarlyStopping

 def accuracy(logits, labels):
     _, indices = torch.max(logits, dim=1)
     correct = torch.sum(indices == labels)
     return correct.item() * 1.0 / len(labels)

 def evaluate(model, features, labels, mask):
     model.eval()
     with torch.no_grad():

@@ -32,6 +35,7 @@ def evaluate(model, features, labels, mask):
         labels = labels[mask]
         return accuracy(logits, labels)

 def main(args):
     # load and preprocess dataset
     data = load_data(args)

@@ -45,7 +49,7 @@ def main(args):
     n_edges = data.graph.number_of_edges()
     print("""----Data statistics------'
       #Edges %d
       #Classes %d
       #Train samples %d
       #Val samples %d
       #Test samples %d""" %

@@ -85,12 +89,14 @@ def main(args):
                 args.alpha,
                 args.residual)
     print(model)
+    stopper = EarlyStopping(patience=100)
     if cuda:
         model.cuda()
     loss_fcn = torch.nn.CrossEntropyLoss()
     # use optimizer
     optimizer = torch.optim.Adam(
         model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
     # initialize graph
     dur = []

@@ -115,6 +121,8 @@ def main(args):
             val_acc = accuracy(logits[val_mask], labels[val_mask])
         else:
             val_acc = evaluate(model, features, labels, val_mask)
+            if stopper.step(val_acc, model):
+                break
         print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
               " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".

@@ -122,9 +130,11 @@ def main(args):
                      val_acc, n_edges / np.mean(dur) / 1000))
     print()
+    model.load_state_dict(torch.load('es_checkpoint.pt'))
     acc = evaluate(model, features, labels, test_mask)
     print("Test Accuracy {:.4f}".format(acc))

 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='GAT')
examples/pytorch/gat/utils.py (new file, 0 → 100644)

import numpy as np
import torch


class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def step(self, acc, model):
        score = acc
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        '''Saves model when validation loss decrease.'''
        torch.save(model.state_dict(), 'es_checkpoint.pt')
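
And the PyTorch counterpart, again as a minimal sketch of the intended call pattern rather than part of the commit: the nn.Linear stand-in model and the hard-coded accuracy sequence are assumptions, while the step / 'es_checkpoint.pt' / load_state_dict flow matches the change to examples/pytorch/gat/train.py above.

# Illustrative sketch (not part of the commit): the same patience-based loop in PyTorch.
import torch
import torch.nn as nn

from utils import EarlyStopping   # assumes this utils.py is on the import path

model = nn.Linear(4, 2)           # hypothetical stand-in for the GAT model
stopper = EarlyStopping(patience=3)

for epoch, val_acc in enumerate([0.70, 0.75, 0.74, 0.73, 0.72, 0.71]):
    # val_acc stands in for evaluate(model, features, labels, val_mask)
    if stopper.step(val_acc, model):   # writes 'es_checkpoint.pt' whenever val_acc improves
        print(f"early stop at epoch {epoch}")
        break

model.load_state_dict(torch.load('es_checkpoint.pt'))   # restore the best checkpoint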