OpenDAS / Fairseq · commit 7d560402
"src/vscode:/vscode.git/clone" did not exist on "b56f102765707ce34c7016409a8b901942edb977"
Commit 7d560402, authored May 28, 2018 by alexeib; committed by Myle Ott on Jun 15, 2018.

record end_of_epoch in checkpoint

Parent: 978c125a
Showing 2 changed files with 14 additions and 8 deletions:

  tests/test_train.py  +6 -6
  train.py             +8 -2
tests/test_train.py

```diff
@@ -13,9 +13,9 @@ from unittest.mock import MagicMock, patch
 
 import train
 
-def mock_trainer(epoch, num_updates):
+def mock_trainer(epoch, num_updates, end_of_epoch):
     trainer = MagicMock()
-    trainer.load_checkpoint.return_value = {'epoch': epoch}
+    trainer.load_checkpoint.return_value = {'epoch': epoch, 'end_of_epoch': end_of_epoch}
     trainer.get_num_updates.return_value = num_updates
     return trainer
```
```diff
@@ -38,21 +38,21 @@ class TestLoadCheckpoint(unittest.TestCase):
         [p.start() for p in self.applied_patches]
 
     def test_load_partial_checkpoint(self):
-        trainer = mock_trainer(2, 200)
+        trainer = mock_trainer(2, 200, False)
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
         self.assertEqual(epoch, 2)
         self.assertEqual(next(ds), 50)
 
     def test_load_full_checkpoint(self):
-        trainer = mock_trainer(2, 150)
+        trainer = mock_trainer(2, 300, True)
         loader = mock_loader(150)
         epoch, ds = train.load_checkpoint(MagicMock(), trainer, loader)
-        self.assertEqual(epoch, 2)
+        self.assertEqual(epoch, 3)
         self.assertEqual(next(iter(ds)), 0)
 
     def test_load_no_checkpoint(self):
-        trainer = mock_trainer(0, 0)
+        trainer = mock_trainer(0, 0, False)
         loader = mock_loader(150)
         self.patches['os.path.isfile'].return_value = False
```
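With 150 batches per epoch from mock_loader(150), the updated tests pin down the resume arithmetic. A minimal walk-through of the two cases (not part of the test file, assuming the iterator semantics in load_checkpoint below):

```python
# Hypothetical walk-through of the two test cases above.
batches_per_epoch = 150

# test_load_partial_checkpoint: epoch 2, 200 updates, end_of_epoch=False.
# Epoch 1 accounts for 150 updates, so 200 - 150 = 50 batches of epoch 2
# are already done; the restored iterator should yield batch 50 next.
completed_batches = 200 - batches_per_epoch
assert completed_batches == 50  # hence self.assertEqual(next(ds), 50)

# test_load_full_checkpoint: epoch 2, 300 updates, end_of_epoch=True.
# The checkpoint sits exactly on an epoch boundary, so training resumes
# at epoch 3, batch 0 - hence assertEqual(epoch, 3) and next(iter(ds)) == 0.
```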
train.py

```diff
@@ -280,6 +280,7 @@ def save_checkpoint(trainer, args, epoch, end_of_epoch, val_loss):
         'epoch': epoch,
         'val_loss': val_loss,
         'wall_time': trainer.get_meter('wall').elapsed_time,
+        'end_of_epoch': end_of_epoch,
     }
 
     if end_of_epoch and not args.no_epoch_checkpoints:
```
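For reference, a sketch of the extra state a checkpoint now carries; the keys match the save_checkpoint diff above, but the values here are illustrative only:

```python
# Illustrative values only; the keys are the ones written by save_checkpoint.
extra_state = {
    'epoch': 2,             # epoch in which the checkpoint was written
    'val_loss': 4.21,       # hypothetical validation loss
    'wall_time': 1234.5,    # hypothetical elapsed time from the 'wall' meter
    'end_of_epoch': False,  # new flag: saved mid-epoch, not at a boundary
}
```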
```diff
@@ -314,9 +315,10 @@ def load_checkpoint(args, trainer, train_dataloader):
     extra_state = trainer.load_checkpoint(checkpoint_path)
     if extra_state is not None:
         epoch = extra_state['epoch']
+        end_of_epoch = extra_state.get('end_of_epoch', True)
         trainer_updates = trainer.get_num_updates()
-        print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(checkpoint_path, epoch, trainer_updates))
+        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
         trainer.lr_step(epoch)
         updates = 0
```
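The .get default is what keeps older checkpoints loadable: they predate the 'end_of_epoch' key, so they are treated as end-of-epoch saves and resume at the top of the next epoch rather than mid-epoch. A quick illustration with hypothetical dicts:

```python
old_ckpt = {'epoch': 2}                         # written before this commit
new_ckpt = {'epoch': 2, 'end_of_epoch': False}  # written after this commit

assert old_ckpt.get('end_of_epoch', True) is True   # falls into the else branch below
assert new_ckpt.get('end_of_epoch', True) is False  # eligible for mid-epoch resume
```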
```diff
@@ -324,14 +326,18 @@ def load_checkpoint(args, trainer, train_dataloader):
         ds = next(train_dataloader)
         updates += len(ds)
-        if ds is not None and updates > trainer_updates:
+        if not end_of_epoch and ds is not None and updates > trainer_updates:
             completed_batches = len(ds) - (updates - trainer_updates)
             assert completed_batches >= 0
             ds = iter(ds)
+            print('| resuming from batch {}'.format(completed_batches + 1))
             # consume completed batches
             next(islice(ds, completed_batches, completed_batches), None)
         else:
+            if not end_of_epoch:
+                print('| WARNING: checkpoint is not at end of epoch')
             ds = next(train_dataloader)
             epoch += 1
```
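The batch-skipping line is the standard itertools "consume" recipe: next(islice(it, n, n), None) advances an iterator by n items without materializing them. A self-contained sketch of the idiom:

```python
from itertools import islice

def consume(iterator, n):
    """Advance `iterator` by n items (the itertools consume recipe)."""
    next(islice(iterator, n, n), None)

ds = iter(range(150))  # stand-in for one epoch's batch iterator
consume(ds, 50)        # skip the 50 batches already trained on
print(next(ds))        # -> 50, the first batch left to train
```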