OpenDAS / Fairseq

Commit 2d27ae08: Simulated big batches
Authored Apr 07, 2018 by Sergey Edunov; committed by Myle Ott on Jun 15, 2018
Parent: 60c4081b
Showing 3 changed files with 88 additions and 50 deletions:

    fairseq/options.py        +2   -0
    fairseq/trainer.py        +78  -47
    singleprocess_train.py    +8   -3
fairseq/options.py

@@ -176,6 +176,8 @@ def add_optimization_args(parser):
                            ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--update-freq', default=1, type=int, metavar='N',
+                       help='update parameters every N batches')
     return group
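The new flag lives in the same optimization argument group as --curriculum. A minimal, standalone sketch of how the option is exposed through argparse (the parser here is built directly for illustration; fairseq constructs it inside add_optimization_args):

# Sketch: the --update-freq option as an ordinary argparse argument.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--update-freq', default=1, type=int, metavar='N',
                    help='update parameters every N batches')
args = parser.parse_args(['--update-freq', '16'])

# With update-freq=16, gradients from 16 consecutive batches are accumulated
# before a single optimizer step, simulating a 16x larger batch.
print(args.update_freq)  # 16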
fairseq/trainer.py

@@ -9,7 +9,8 @@
 Train a network on multiple GPUs.
 """

-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
+from itertools import chain

 import math
 import torch
@@ -55,6 +56,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()  # % of updates clipped
         self.meters['oom'] = AverageMeter()   # out of memory

+        self._buffered_stats = defaultdict(lambda: [])
         self._max_bsz_seen = 0
         self._num_updates = 0
         self._optim_history = None
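The buffer is a defaultdict of lists, so each per-batch statistic can be appended without pre-registering keys. A small sketch of the same pattern (the key names match the diff; the values are made up):

from collections import defaultdict

_buffered_stats = defaultdict(lambda: [])

# Each call to train_step(..., update_params=False) appends one entry per key.
_buffered_stats['sample_sizes'].append(3200)
_buffered_stats['ooms_fwd'].append(0)
_buffered_stats['sample_sizes'].append(2800)
_buffered_stats['ooms_fwd'].append(0)

print(_buffered_stats['sample_sizes'])   # [3200, 2800]
print(sum(_buffered_stats['ooms_fwd']))  # 0

# clear() resets every buffer once the deferred update has been applied.
_buffered_stats.clear()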
@@ -86,40 +88,68 @@ class Trainer(object):
         return extra_state

-    def train_step(self, sample):
+    def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
         sample = self._prepare_sample(sample, volatile=False)

-        # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
+        # forward and backward pass
+        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
+        oom_bwd = self._backward(loss)

-        # aggregate stats and logging outputs
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        # buffer stats and logging outputs
+        self._buffered_stats['sample_sizes'].append(sample_size)
+        self._buffered_stats['logging_outputs'].append(logging_output)
+        self._buffered_stats['ooms_fwd'].append(oom_fwd)
+        self._buffered_stats['ooms_bwd'].append(oom_bwd)

-        # backward pass, all-reduce gradients and take an optimization step
-        grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom)
+        # update parameters
+        if update_params:
+            # gather logging outputs from all GPUs
+            sample_sizes = self._buffered_stats['sample_sizes']
+            logging_outputs = self._buffered_stats['logging_outputs']
+            ooms_fwd = self._buffered_stats['ooms_fwd']
+            ooms_bwd = self._buffered_stats['ooms_bwd']
+            if self.args.distributed_world_size > 1:
+                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
+                    lambda l: list(chain.from_iterable(l)),
+                    zip(*distributed_utils.all_gather_list(
+                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
+                    ))
+                )
+            ooms_fwd = sum(ooms_fwd)
+            ooms_bwd = sum(ooms_bwd)

-        # update meters
-        self.meters['wps'].update(ntokens)
-        self.meters['ups'].update(1.)
-        self.meters['wpb'].update(ntokens)
-        self.meters['bsz'].update(nsentences)
-        self.meters['gnorm'].update(grad_norm)
-        self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
-        self.meters['oom'].update(ooms_fwd + ooms_bwd)
+            # aggregate stats and logging outputs
+            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

-        # update loss meters for training
-        if 'loss' in agg_logging_output:
-            self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
-        # criterions can optionally log the NLL loss too
-        if 'nll_loss' in agg_logging_output:
-            self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+            # all-reduce gradients and take an optimization step
+            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+            grad_norm = self._opt(grad_denom)

-        return agg_logging_output
+            # update meters
+            self.meters['wps'].update(ntokens)
+            self.meters['ups'].update(1.)
+            self.meters['wpb'].update(ntokens)
+            self.meters['bsz'].update(nsentences)
+            self.meters['gnorm'].update(grad_norm)
+            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
+            self.meters['oom'].update(ooms_fwd + ooms_bwd)
+
+            # update loss meters for training
+            if 'loss' in agg_logging_output:
+                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
+            # criterions can optionally log the NLL loss too
+            if 'nll_loss' in agg_logging_output:
+                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+
+            self._buffered_stats.clear()
+            return agg_logging_output
+        else:
+            return None  # buffering updates

     def _forward(self, sample, eval=False):
         # prepare model and optimizer
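Outside fairseq's Trainer, the update_params pattern reduces to ordinary gradient accumulation in PyTorch. A minimal, self-contained sketch under the assumption of a plain nn.Module and SGD (none of these names come from the commit; the rescale-by-total-examples step stands in for grad_denom):

# Hypothetical standalone sketch of the update_params idea: run forward and
# backward on every micro-batch, but only step the optimizer on the last one.
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update_freq = 4

micro_batches = [torch.randn(8, 10) for _ in range(update_freq)]
targets = [torch.randn(8, 1) for _ in range(update_freq)]
criterion = nn.MSELoss(reduction='sum')

total_examples = 0
for i, (x, y) in enumerate(zip(micro_batches, targets)):
    loss = criterion(model(x), y)
    loss.backward()                 # grads accumulate across calls
    total_examples += x.size(0)
    update_params = (i + 1) % update_freq == 0
    if update_params:
        # rescale the summed grads by the combined batch size, as grad_denom does
        for p in model.parameters():
            p.grad.data.div_(total_examples)
        optimizer.step()
        optimizer.zero_grad()       # cleared only after the deferred step
        total_examples = 0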
@@ -127,7 +157,6 @@ class Trainer(object):
             self.model.eval()
         else:
             self.model.train()
-            self.optimizer.zero_grad()

         loss = None
         sample_size = 0
@@ -152,19 +181,9 @@ class Trainer(object):
             else:
                 raise e

-        # synchronize logging outputs for multi-GPU training
-        if self.args.distributed_world_size > 1:
-            sample_sizes, logging_outputs, ooms = zip(*list(
-                distributed_utils.all_gather_list((sample_size, logging_output, oom))))
-            ooms = sum(ooms)
-        else:
-            sample_sizes = [sample_size]
-            logging_outputs = [logging_output]
-            ooms = oom
-
-        return loss, sample_sizes, logging_outputs, ooms
+        return loss, sample_size, logging_output, oom

-    def _backward_and_opt(self, loss, grad_denom):
+    def _backward(self, loss):
         oom = 0
         if loss is not None:
             try:
@@ -179,7 +198,9 @@ class Trainer(object):
                 self.optimizer.zero_grad()
             else:
                 raise e
+        return oom

+    def _opt(self, grad_denom):
         # all-reduce grads and rescale by grad_denom
         if self.args.distributed_world_size > 1:
             grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
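The new _opt helper owns everything that should happen exactly once per simulated big batch: rescaling accumulated gradients by grad_denom, clipping, the optimizer step, and zeroing grads. A single-process sketch of that sequence (the all-reduce is omitted; clip_grad_norm_ here is the standard torch helper, not fairseq's wrapper, and the function name is made up):

# Hypothetical single-GPU version of _opt: rescale accumulated grads,
# clip, step, then zero, mirroring the order in the diff.
import torch

def opt_step(model, optimizer, grad_denom, clip_norm):
    # rescale grads by grad_denom (the sum of the buffered sample sizes)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.data.div_(grad_denom)
    # clip and take an optimization step
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
    optimizer.step()
    optimizer.zero_grad()
    return grad_norm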
@@ -197,12 +218,13 @@ class Trainer(object):
         # take an optimization step
         self.optimizer.step()
+        self.optimizer.zero_grad()
         self._num_updates += 1

         # update learning rate
         self.lr_scheduler.step_update(self._num_updates)

-        return grad_norm, oom
+        return grad_norm

     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
@@ -210,8 +232,17 @@ class Trainer(object):
         sample = self._prepare_sample(sample, volatile=True)

         # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True)
-        assert not ooms_fwd, 'Ran out of memory during validation'
+        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
+        assert not oom_fwd, 'Ran out of memory during validation'
+
+        # gather logging outputs from all GPUs
+        if self.args.distributed_world_size > 1:
+            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
+                (sample_size, logging_output)
+            ))
+        else:
+            sample_sizes = [sample_size]
+            logging_outputs = [logging_output]

         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
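What makes the buffered updates behave like a genuinely larger batch is that each micro-batch loss is a sum (not a mean) and the accumulated gradient is divided once by the combined grad_denom. A small numeric sanity check of that equivalence, using plain PyTorch and made-up data (a least-squares loss stands in for the criterion):

# Hedged check: summing losses over two micro-batches and dividing the grads
# by the total example count matches the gradient of one big mean-loss batch.
import torch

torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)
x1, x2 = torch.randn(3, 5), torch.randn(4, 5)
y1, y2 = torch.randn(3), torch.randn(4)

# big batch: mean squared error over all 7 examples
loss_big = (((torch.cat([x1, x2]) @ w) - torch.cat([y1, y2])) ** 2).mean()
grad_big, = torch.autograd.grad(loss_big, w)

# simulated big batch: summed per-batch losses, then divide by the total count
loss_sum = (((x1 @ w) - y1) ** 2).sum() + (((x2 @ w) - y2) ** 2).sum()
grad_acc, = torch.autograd.grad(loss_sum, w)
grad_acc = grad_acc / 7

print(torch.allclose(grad_big, grad_acc))  # True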
singleprocess_train.py

@@ -132,6 +132,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         num_shards=args.distributed_world_size,
     )
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    epoch_size = len(itr)
     itr = itertools.islice(progress, batch_offset, None)

     # reset training meters
@@ -143,7 +144,12 @@ def train(args, trainer, dataset, epoch, batch_offset):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
-        log_output = trainer.train_step(sample)
+        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+            # buffer updates according to --update-freq
+            trainer.train_step(sample, update_params=False)
+            continue
+        else:
+            log_output = trainer.train_step(sample, update_params=True)

         # log mid-epoch stats
         stats = get_training_stats(trainer)
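The loop condition decides which iterations merely buffer and which trigger the deferred update: a step fires when (i + 1) is a multiple of update_freq, and always on the last batch of the epoch so no buffered gradients are dropped. A tiny runnable illustration of the resulting schedule (epoch_size and update_freq values are made up):

# Which iterations perform a parameter update under the new condition.
epoch_size, update_freq, batch_offset = 10, 4, 0

for i in range(batch_offset, epoch_size):
    if i < epoch_size - 1 and (i + 1) % update_freq > 0:
        action = 'buffer'   # train_step(sample, update_params=False)
    else:
        action = 'update'   # train_step(sample, update_params=True)
    print(i, action)
# updates fire at i = 3, 7 and at the final batch i = 9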
@@ -157,9 +163,8 @@ def train(args, trainer, dataset, epoch, batch_offset):
             stats[k] = extra_meters[k].avg
         progress.log(stats)

-        # save mid-epoch checkpoints
+        # ignore the first mini-batch in words-per-second calculation
         if i == batch_offset:
-            # ignore the first mini-batch in words-per-second calculation
             trainer.get_meter('wps').reset()

         # save mid-epoch checkpoints