OpenDAS / Fairseq · Commits

Commit 9c102784
Authored Aug 24, 2018 by Myle Ott
Parent: f84e1ed4

    Add training wall time meter
Showing 2 changed files with 76 additions and 58 deletions:

    fairseq/trainer.py   +75  -58
    train.py              +1   -0
fairseq/trainer.py  (view file @ 9c102784)
@@ -16,7 +16,7 @@ from itertools import chain
 import torch
 from fairseq import distributed_utils, optim, utils
-from fairseq.meters import AverageMeter, TimeMeter
+from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.optim import lr_scheduler
@@ -54,6 +54,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()   # % of updates clipped
         self.meters['oom'] = AverageMeter()    # out of memory
         self.meters['wall'] = TimeMeter()      # wall time in seconds
+        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds
         self._buffered_stats = defaultdict(lambda: [])
         self._flat_grads = None
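The commit relies only on a small start/stop/sum interface from the new meter. Below is a minimal sketch of a StopwatchMeter consistent with the calls made in this diff; the real class lives in fairseq/meters.py and may differ in details:

import time

class StopwatchMeter(object):
    """Accumulates the total duration of bracketed events, in seconds."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0            # total seconds over all completed intervals
        self.n = 0              # number of completed intervals
        self.start_time = None

    def start(self):
        self.start_time = time.time()

    def stop(self, n=1):
        # ignore a stop() without a matching start(), e.g. right after reset()
        if self.start_time is not None:
            self.sum += time.time() - self.start_time
            self.n += n
            self.start_time = None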
@@ -109,9 +110,14 @@ class Trainer(object):
             self.meters = extra_state['train_meters']
             del extra_state['train_meters']
+
+            # reset TimeMeters, since their start times don't make sense anymore
+            for meter in self.meters.values():
+                if isinstance(meter, TimeMeter):
+                    meter.reset()
         return extra_state

-    def train_step(self, sample, update_params=True):
+    def train_step(self, sample, update_params=True, dummy_batch=False):
         """Do forward, backward and parameter update."""
         # Set seed based on args.seed and the update number so that we get
         # reproducible results when resuming from checkpoints
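The reset loop added here matters because a TimeMeter measures elapsed time against a wall-clock timestamp captured when it was created; after meters are unpickled from a checkpoint, that timestamp belongs to the previous process, so "their start times don't make sense anymore". An illustrative sketch of that failure mode, assuming a TimeMeter shaped like the one in fairseq/meters.py (details may differ):

import time

class TimeMeter(object):
    """Tracks the average rate of some event per second since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.init = time.time()   # stale after a checkpoint restore, hence the reset above
        self.n = 0

    def update(self, val=1):
        self.n += val

    @property
    def elapsed_time(self):
        return time.time() - self.init

    @property
    def avg(self):
        # a stale self.init inflates elapsed_time and deflates rates like 'wps'
        return self.n / self.elapsed_time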
@@ -119,6 +125,9 @@ class Trainer(object):
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)

+        if not dummy_batch:
+            self.meters['train_wall'].start()
+
         # forward and backward pass
         sample = self._prepare_sample(sample)
         loss, sample_size, logging_output, oom_fwd = self._forward(sample)
@@ -132,62 +141,70 @@ class Trainer(object):
         # update parameters
         if update_params:
-            # gather logging outputs from all replicas
-            sample_sizes = self._buffered_stats['sample_sizes']
-            logging_outputs = self._buffered_stats['logging_outputs']
-            ooms_fwd = self._buffered_stats['ooms_fwd']
-            ooms_bwd = self._buffered_stats['ooms_bwd']
-            if self.args.distributed_world_size > 1:
-                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
-                    lambda l: list(chain.from_iterable(l)),
-                    zip(*distributed_utils.all_gather_list(
-                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
-                    ))
-                )
-            ooms_fwd = sum(ooms_fwd)
-            ooms_bwd = sum(ooms_bwd)
-
-            if ooms_fwd == self.args.distributed_world_size:
-                print('| WARNING: OOM in all workers, skipping batch')
-                self.zero_grad()
-                return None
-
-            # aggregate stats and logging outputs
-            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
-            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-
-            try:
-                # all-reduce and rescale gradients, then take an optimization step
-                grad_norm = self._all_reduce_and_rescale(grad_denom)
-                self._opt()
-
-                # update meters
-                self.meters['wps'].update(ntokens)
-                self.meters['ups'].update(1.)
-                self.meters['wpb'].update(ntokens)
-                self.meters['bsz'].update(nsentences)
-                if grad_norm is not None:
-                    self.meters['gnorm'].update(grad_norm)
-                    self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
-                self.meters['oom'].update(ooms_fwd + ooms_bwd)
-
-                # update loss meters for training
-                if 'loss' in agg_logging_output:
-                    self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
-                # criterions can optionally log the NLL loss too
-                if 'nll_loss' in agg_logging_output:
-                    self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
-            except OverflowError as e:
-                self.zero_grad()
-                print('| WARNING: overflow detected, ' + str(e))
-
-            self.clear_buffered_stats()
-            return agg_logging_output
+            agg_logging_output = self._update_params()
         else:
-            return None  # buffering updates
+            agg_logging_output = None  # buffering updates
+
+        if not dummy_batch:
+            self.meters['train_wall'].stop()
+
+        return agg_logging_output
+
+    def _update_params(self):
+        # gather logging outputs from all replicas
+        sample_sizes = self._buffered_stats['sample_sizes']
+        logging_outputs = self._buffered_stats['logging_outputs']
+        ooms_fwd = self._buffered_stats['ooms_fwd']
+        ooms_bwd = self._buffered_stats['ooms_bwd']
+        if self.args.distributed_world_size > 1:
+            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
+                lambda l: list(chain.from_iterable(l)),
+                zip(*distributed_utils.all_gather_list(
+                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
+                ))
+            )
+        ooms_fwd = sum(ooms_fwd)
+        ooms_bwd = sum(ooms_bwd)
+
+        if ooms_fwd == self.args.distributed_world_size:
+            print('| WARNING: OOM in all workers, skipping batch')
+            self.zero_grad()
+            return None
+
+        # aggregate stats and logging outputs
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+
+        try:
+            # all-reduce and rescale gradients, then take an optimization step
+            grad_norm = self._all_reduce_and_rescale(grad_denom)
+            self._opt()
+
+            # update meters
+            self.meters['wps'].update(ntokens)
+            self.meters['ups'].update(1.)
+            self.meters['wpb'].update(ntokens)
+            self.meters['bsz'].update(nsentences)
+            if grad_norm is not None:
+                self.meters['gnorm'].update(grad_norm)
+                self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
+            self.meters['oom'].update(ooms_fwd + ooms_bwd)
+
+            # update loss meters for training
+            if 'loss' in agg_logging_output:
+                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
+            # criterions can optionally log the NLL loss too
+            if 'nll_loss' in agg_logging_output:
+                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+        except OverflowError as e:
+            self.zero_grad()
+            print('| WARNING: overflow detected, ' + str(e))

+        self.clear_buffered_stats()
+        return agg_logging_output

     def _forward(self, sample, eval=False):
         loss = None
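Net effect: train_step now brackets only real training work. The stopwatch starts right after seeding, stops after the (possibly buffered) parameter update, and dummy batches skip both calls. A toy demonstration of the resulting split between the two meters, reusing the sketch classes (and the time import) from above, with sleeps standing in for compute:

wall = TimeMeter()             # total elapsed time since reset, idle included
train_wall = StopwatchMeter()  # only time inside bracketed train steps

train_wall.start()
time.sleep(0.2)                # stand-in for forward/backward/optimizer step
train_wall.stop()

time.sleep(0.1)                # stand-in for data loading or validation: not timed

print('{:.1f}'.format(wall.elapsed_time))  # ~0.3 seconds of wall-clock time
print('{:.1f}'.format(train_wall.sum))     # ~0.2 seconds of training time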
@@ -320,7 +337,7 @@ class Trainer(object):
     def dummy_train_step(self, dummy_batch):
         """Dummy training step for warming caching allocator."""
-        self.train_step(dummy_batch, update_params=False)
+        self.train_step(dummy_batch, update_params=False, dummy_batch=True)
         self.zero_grad()
         self.clear_buffered_stats()
train.py  (view file @ 9c102784)
@@ -185,6 +185,7 @@ def get_training_stats(trainer):
     if trainer.get_meter('loss_scale') is not None:
         stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
     stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
+    stats['train_wall'] = round(trainer.get_meter('train_wall').sum)
     return stats
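With that one-line addition, the training stats report both numbers, and their difference is time spent outside train_step (data loading, validation, checkpoint saving). A hypothetical reading, with made-up values:

stats = get_training_stats(trainer)   # trainer is a fairseq Trainer instance
print('| wall {} | train_wall {}'.format(stats['wall'], stats['train_wall']))
# e.g. "| wall 3600 | train_wall 3120": roughly 480 seconds went to work
# outside training steps. Note 'wall' reads TimeMeter.elapsed_time (time since
# the meter was reset), while 'train_wall' reads StopwatchMeter.sum.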