OpenDAS / Fairseq / Commits / 2d27ae08

Simulated big batches

Authored Apr 07, 2018 by Sergey Edunov; committed by Myle Ott on Jun 15, 2018.
Parent: 60c4081b
Showing 3 changed files with 88 additions and 50 deletions (+88 -50):
fairseq/options.py        +2   -0
fairseq/trainer.py        +78  -47
singleprocess_train.py    +8   -3
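The commit message is terse, so a note on what it does: with the new --update-freq N flag, the trainer runs forward/backward on N consecutive batches while buffering their gradients and statistics, then makes a single parameter update whose gradient denominator spans all N batches. This approximates training with a batch N times larger than fits in GPU memory. A minimal self-contained sketch of the accumulation pattern in plain PyTorch (toy model and data, not fairseq code):

    import torch

    model = torch.nn.Linear(4, 1)                    # toy stand-in for the real model
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    micro_batches = [torch.randn(8, 4) for _ in range(4)]
    update_freq = 4

    opt.zero_grad()
    for i, x in enumerate(micro_batches):
        # normalize by the number of buffered batches so the summed
        # gradients match a single batch 4x as large
        loss = model(x).pow(2).mean() / update_freq
        loss.backward()                              # gradients accumulate in .grad
        if (i + 1) % update_freq == 0:
            opt.step()                               # one update per update_freq batches
            opt.zero_grad()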
fairseq/options.py

@@ -176,6 +176,8 @@ def add_optimization_args(parser):
                            ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--update-freq', default=1, type=int, metavar='N',
+                       help='update parameters every N batches')
     return group
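With the flag in place, simulating a 16x larger batch is just a command-line change. A hypothetical invocation (the dataset path, architecture, and other flags are illustrative, not part of this commit):

    python singleprocess_train.py data-bin/iwslt14.tokenized.de-en \
        --arch fconv_iwslt_de_en --max-tokens 4000 --update-freq 16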
fairseq/trainer.py

@@ -9,7 +9,8 @@
 Train a network on multiple GPUs.
 """
 
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
+from itertools import chain
 import math
 
 import torch

@@ -55,6 +56,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()   # % of updates clipped
         self.meters['oom'] = AverageMeter()    # out of memory
+        self._buffered_stats = defaultdict(lambda: [])
         self._max_bsz_seen = 0
         self._num_updates = 0
         self._optim_history = None

@@ -86,22 +88,46 @@ class Trainer(object):
         return extra_state
 
-    def train_step(self, sample):
+    def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
         sample = self._prepare_sample(sample, volatile=False)
 
-        # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
+        # forward and backward pass
+        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
+        oom_bwd = self._backward(loss)
 
-        # aggregate stats and logging outputs
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
-
-        # backward pass, all-reduce gradients and take an optimization step
-        grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom)
-
-        # update meters
-        self.meters['wps'].update(ntokens)
+        # buffer stats and logging outputs
+        self._buffered_stats['sample_sizes'].append(sample_size)
+        self._buffered_stats['logging_outputs'].append(logging_output)
+        self._buffered_stats['ooms_fwd'].append(oom_fwd)
+        self._buffered_stats['ooms_bwd'].append(oom_bwd)
+
+        # update parameters
+        if update_params:
+            # gather logging outputs from all GPUs
+            sample_sizes = self._buffered_stats['sample_sizes']
+            logging_outputs = self._buffered_stats['logging_outputs']
+            ooms_fwd = self._buffered_stats['ooms_fwd']
+            ooms_bwd = self._buffered_stats['ooms_bwd']
+            if self.args.distributed_world_size > 1:
+                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
+                    lambda l: list(chain.from_iterable(l)),
+                    zip(*distributed_utils.all_gather_list(
+                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
+                    ))
+                )
+            ooms_fwd = sum(ooms_fwd)
+            ooms_bwd = sum(ooms_bwd)
+
+            # aggregate stats and logging outputs
+            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+
+            # all-reduce gradients and take an optimization step
+            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+            grad_norm = self._opt(grad_denom)
+
+            # update meters
+            self.meters['wps'].update(ntokens)
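An aside on the gather logic in train_step above: in the distributed case, all_gather_list returns one (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd) tuple per worker, where each field is now a list of buffered values. zip(*...) regroups the tuples by field, and chain.from_iterable flattens each group of per-worker lists into one flat list. A standalone illustration with toy data in place of the real gathered stats:

    from itertools import chain

    # pretend two workers each buffered two batches before this update
    gathered = [
        ([10, 12], [1, 0]),   # worker 0: (sample_sizes, ooms_fwd)
        ([11, 9],  [0, 0]),   # worker 1
    ]
    sample_sizes, ooms_fwd = map(
        lambda l: list(chain.from_iterable(l)),
        zip(*gathered),
    )
    print(sample_sizes)   # [10, 12, 11, 9]
    print(sum(ooms_fwd))  # 1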
@@ -119,7 +145,11 @@ class Trainer(object):
-        if 'nll_loss' in agg_logging_output:
-            self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
-        return agg_logging_output
+            if 'nll_loss' in agg_logging_output:
+                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+            self._buffered_stats.clear()
+            return agg_logging_output
+        else:
+            return None  # buffering updates
 
     def _forward(self, sample, eval=False):
         # prepare model and optimizer

@@ -127,7 +157,6 @@ class Trainer(object):
             self.model.eval()
         else:
             self.model.train()
-            self.optimizer.zero_grad()
 
         loss = None
         sample_size = 0

@@ -152,19 +181,9 @@ class Trainer(object):
             else:
                 raise e
 
-        # synchronize logging outputs for multi-GPU training
-        if self.args.distributed_world_size > 1:
-            sample_sizes, logging_outputs, ooms = zip(*list(
-                distributed_utils.all_gather_list((sample_size, logging_output, oom))))
-            ooms = sum(ooms)
-        else:
-            sample_sizes = [sample_size]
-            logging_outputs = [logging_output]
-            ooms = oom
-
-        return loss, sample_sizes, logging_outputs, ooms
+        return loss, sample_size, logging_output, oom
 
-    def _backward_and_opt(self, loss, grad_denom):
+    def _backward(self, loss):
         oom = 0
         if loss is not None:
             try:

@@ -179,7 +198,9 @@ class Trainer(object):
                     self.optimizer.zero_grad()
                 else:
                     raise e
+        return oom
 
+    def _opt(self, grad_denom):
         # all-reduce grads and rescale by grad_denom
         if self.args.distributed_world_size > 1:
             grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
@@ -197,12 +218,13 @@ class Trainer(object):
         # take an optimization step
         self.optimizer.step()
+        self.optimizer.zero_grad()
         self._num_updates += 1
 
         # update learning rate
         self.lr_scheduler.step_update(self._num_updates)
 
-        return grad_norm, oom
+        return grad_norm
 
     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""

@@ -210,8 +232,17 @@ class Trainer(object):
         sample = self._prepare_sample(sample, volatile=True)
 
         # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True)
-        assert not ooms_fwd, 'Ran out of memory during validation'
+        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
+        assert not oom_fwd, 'Ran out of memory during validation'
+
+        # gather logging outputs from all GPUs
+        if self.args.distributed_world_size > 1:
+            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
+                (sample_size, logging_output)
+            ))
+        else:
+            sample_sizes = [sample_size]
+            logging_outputs = [logging_output]
 
         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
singleprocess_train.py

@@ -132,6 +132,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         num_shards=args.distributed_world_size,
     )
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    epoch_size = len(itr)
     itr = itertools.islice(progress, batch_offset, None)
 
     # reset training meters

@@ -143,7 +144,12 @@ def train(args, trainer, dataset, epoch, batch_offset):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
-        log_output = trainer.train_step(sample)
+        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+            # buffer updates according to --update-freq
+            trainer.train_step(sample, update_params=False)
+            continue
+        else:
+            log_output = trainer.train_step(sample, update_params=True)
 
         # log mid-epoch stats
         stats = get_training_stats(trainer)
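The loop condition is worth unpacking: (i + 1) % args.update_freq > 0 buffers a batch unless it completes a group of update_freq batches, and i < epoch_size - 1 forces an update on the last batch of the epoch so no buffered gradients are silently dropped. A toy trace of the schedule (the epoch_size and update_freq values are illustrative):

    epoch_size, update_freq = 10, 4

    for i in range(epoch_size):
        if i < epoch_size - 1 and (i + 1) % update_freq > 0:
            print(i, 'buffer')    # train_step(sample, update_params=False)
        else:
            print(i, 'update')    # train_step(sample, update_params=True)

    # batches 0-2 buffer, 3 updates; 4-6 buffer, 7 updates; 8 buffers, 9 updates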
@@ -157,9 +163,8 @@ def train(args, trainer, dataset, epoch, batch_offset):
             stats[k] = extra_meters[k].avg
         progress.log(stats)
 
-        # save mid-epoch checkpoints
-        if i == batch_offset:
-            # ignore the first mini-batch in words-per-second calculation
-            trainer.get_meter('wps').reset()
+        # ignore the first mini-batch in words-per-second calculation
+        if i == batch_offset:
+            trainer.get_meter('wps').reset()
 
         # save mid-epoch checkpoints