Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
6643d525
Commit
6643d525
authored
May 30, 2018
by
Myle Ott
Browse files
Use symlinks for redundant checkpoints
parent
24d7de44
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
28 deletions
+38
-28
train.py
train.py
+38
-28
No files found.
train.py
View file @
6643d525
...
@@ -99,8 +99,8 @@ def main(args):
...
@@ -99,8 +99,8 @@ def main(args):
lr
=
trainer
.
lr_step
(
epoch
,
first_val_loss
)
lr
=
trainer
.
lr_step
(
epoch
,
first_val_loss
)
# save checkpoint
# save checkpoint
if
not
args
.
no_save
and
epoch
%
args
.
save_interval
==
0
:
if
epoch
%
args
.
save_interval
==
0
:
save_checkpoint
(
trainer
,
args
,
epoch
,
end_of_epoch
=
True
,
val_loss
=
first_val_loss
)
save_checkpoint
(
args
,
trainer
,
epoch
,
end_of_epoch
=
True
,
val_loss
=
first_val_loss
)
epoch
+=
1
epoch
+=
1
next_ds
=
next
(
train_dataloader
)
next_ds
=
next
(
train_dataloader
)
...
@@ -163,10 +163,9 @@ def train(args, trainer, itr, epoch, dataset):
...
@@ -163,10 +163,9 @@ def train(args, trainer, itr, epoch, dataset):
trainer
.
get_meter
(
'wps'
).
reset
()
trainer
.
get_meter
(
'wps'
).
reset
()
num_updates
=
trainer
.
get_num_updates
()
num_updates
=
trainer
.
get_num_updates
()
if
not
args
.
no_save
and
(
args
.
save_interval_updates
or
0
)
>
0
and
\
if
args
.
save_interval_updates
>
0
and
num_updates
%
args
.
save_interval_updates
==
0
:
num_updates
%
args
.
save_interval_updates
==
0
:
first_val_loss
=
val_loss
(
args
,
trainer
,
dataset
,
epoch
,
num_updates
)
first_val_loss
=
val_loss
(
args
,
trainer
,
dataset
,
epoch
,
num_updates
)
save_checkpoint
(
trainer
,
args
,
epoch
,
end_of_epoch
=
False
,
val_loss
=
first_val_loss
)
save_checkpoint
(
args
,
trainer
,
epoch
,
end_of_epoch
=
False
,
val_loss
=
first_val_loss
)
if
num_updates
>=
max_update
:
if
num_updates
>=
max_update
:
break
break
...
@@ -280,38 +279,49 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None):
...
@@ -280,38 +279,49 @@ def val_loss(args, trainer, dataset, epoch, num_updates=None):
return
losses
[
0
]
if
len
(
losses
)
>
0
else
None
return
losses
[
0
]
if
len
(
losses
)
>
0
else
None
def
save_checkpoint
(
trainer
,
args
,
epoch
,
end_of_epoch
,
val_loss
):
def
save_checkpoint
(
args
,
trainer
,
epoch
,
end_of_epoch
,
val_loss
):
if
args
.
no_save
or
args
.
distributed_rank
>
0
:
return
updates
=
trainer
.
get_num_updates
()
checkpoint_conds
=
collections
.
OrderedDict
()
checkpoint_conds
[
'checkpoint{}.pt'
.
format
(
epoch
)]
=
(
end_of_epoch
and
not
args
.
no_epoch_checkpoints
and
epoch
%
args
.
save_interval
==
0
)
checkpoint_conds
[
'checkpoint_{}_{}.pt'
.
format
(
epoch
,
updates
)]
=
(
not
end_of_epoch
and
args
.
save_interval_updates
>
0
and
updates
%
args
.
save_interval_updates
==
0
)
checkpoint_conds
[
'checkpoint_best.pt'
]
=
(
not
hasattr
(
save_checkpoint
,
'best'
)
or
val_loss
<
save_checkpoint
.
best
)
checkpoint_conds
[
'checkpoint_last.pt'
]
=
True
# keep this last so that it's a symlink
save_checkpoint
.
best
=
min
(
val_loss
,
getattr
(
save_checkpoint
,
'best'
,
val_loss
))
extra_state
=
{
extra_state
=
{
'best'
:
save_checkpoint
.
best
,
'end_of_epoch'
:
end_of_epoch
,
'epoch'
:
epoch
,
'epoch'
:
epoch
,
'val_loss'
:
val_loss
,
'val_loss'
:
val_loss
,
'wall_time'
:
trainer
.
get_meter
(
'wall'
).
elapsed_time
,
'wall_time'
:
trainer
.
get_meter
(
'wall'
).
elapsed_time
,
'end_of_epoch'
:
end_of_epoch
,
}
}
if
end_of_epoch
and
not
args
.
no_epoch_checkpoints
:
checkpoints
=
[
os
.
path
.
join
(
args
.
save_dir
,
fn
)
for
fn
,
cond
in
checkpoint_conds
.
items
()
if
cond
]
epoch_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint{}.pt'
.
format
(
epoch
))
if
len
(
checkpoints
)
>
0
:
trainer
.
save_checkpoint
(
epoch_filename
,
extra_state
)
for
fn
in
checkpoints
:
elif
not
end_of_epoch
and
args
.
keep_interval_updates
>
0
:
if
os
.
path
.
exists
(
fn
):
checkpoint_filename
=
os
.
path
.
join
(
args
.
save_dir
,
os
.
remove
(
fn
)
'checkpoint_{}_{}.pt'
.
format
(
epoch
,
trainer
.
get_num_updates
()))
trainer
.
save_checkpoint
(
checkpoints
[
0
],
extra_state
)
trainer
.
save_checkpoint
(
checkpoint_filename
,
extra_state
)
for
fn
in
checkpoints
[
1
:]:
# remove old checkpoints
os
.
symlink
(
os
.
path
.
basename
(
checkpoints
[
0
]),
fn
)
checkpoints
=
checkpoint_paths
(
args
.
save_dir
,
pattern
=
r
'checkpoint_\d+_(\d+)\.pt'
)
# checkpoints are sorted in descending order
if
not
end_of_epoch
and
args
.
keep_interval_updates
>
0
:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints
=
utils
.
checkpoint_paths
(
args
.
save_dir
,
pattern
=
r
'checkpoint_\d+_(\d+)\.pt'
)
for
old_chk
in
checkpoints
[
args
.
keep_interval_updates
:]:
for
old_chk
in
checkpoints
[
args
.
keep_interval_updates
:]:
os
.
remove
(
old_chk
)
os
.
remove
(
old_chk
)
assert
val_loss
is
not
None
if
not
hasattr
(
save_checkpoint
,
'best'
)
or
val_loss
<
save_checkpoint
.
best
:
save_checkpoint
.
best
=
val_loss
best_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint_best.pt'
)
extra_state
[
'best'
]
=
val_loss
trainer
.
save_checkpoint
(
best_filename
,
extra_state
)
extra_state
[
'best'
]
=
save_checkpoint
.
best
last_filename
=
os
.
path
.
join
(
args
.
save_dir
,
'checkpoint_last.pt'
)
trainer
.
save_checkpoint
(
last_filename
,
extra_state
)
def
load_checkpoint
(
args
,
trainer
,
train_dataloader
):
def
load_checkpoint
(
args
,
trainer
,
train_dataloader
):
os
.
makedirs
(
args
.
save_dir
,
exist_ok
=
True
)
os
.
makedirs
(
args
.
save_dir
,
exist_ok
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment