OpenDAS / Fairseq · Commits · 2d27ae08
"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "105b77fe346e7a1267e8319073a9353a1b45f395"
Commit 2d27ae08, authored Apr 07, 2018 by Sergey Edunov; committed by Myle Ott, Jun 15, 2018

Simulated big batches
Parent: 60c4081b
Showing 3 changed files with 88 additions and 50 deletions:
- fairseq/options.py (+2, -0)
- fairseq/trainer.py (+78, -47)
- singleprocess_train.py (+8, -3)
fairseq/options.py
```diff
@@ -176,6 +176,8 @@ def add_optimization_args(parser):
                             ' dataset')
     group.add_argument('--curriculum', default=0, type=int, metavar='N',
                        help='sort batches by source length for first N epochs')
+    group.add_argument('--update-freq', default=1, type=int, metavar='N',
+                       help='update parameters every N batches')
     return group
```
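The new `--update-freq` flag is the user-facing knob for the simulated big batches: gradients from N consecutive mini-batches are accumulated before a single parameter update (see the `fairseq/trainer.py` changes below). A back-of-the-envelope sketch of what the flag means for effective batch size; the numbers here are made up for illustration and are not from the commit:

```python
# Hypothetical settings, for illustration only.
max_tokens = 4000        # tokens per GPU per mini-batch (--max-tokens)
world_size = 8           # number of GPUs (--distributed-world-size)
update_freq = 16         # mini-batches buffered per update (--update-freq)

# Each optimizer step now sees the accumulated gradient of roughly:
effective_tokens = max_tokens * world_size * update_freq
print(effective_tokens)  # 512000 tokens -- one "simulated big batch"
```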
fairseq/trainer.py
```diff
@@ -9,7 +9,8 @@
 Train a network on multiple GPUs.
 """

-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
+from itertools import chain
 import math

 import torch
```
```diff
@@ -55,6 +56,7 @@ class Trainer(object):
         self.meters['clip'] = AverageMeter()  # % of updates clipped
         self.meters['oom'] = AverageMeter()   # out of memory
+        self._buffered_stats = defaultdict(lambda: [])
         self._max_bsz_seen = 0
         self._num_updates = 0
         self._optim_history = None
```
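`_buffered_stats` maps a stat name to a growing list, so each key can be appended to without initialization. A standalone sketch of the access pattern (illustrative values, not fairseq code):

```python
from collections import defaultdict

buffered = defaultdict(lambda: [])  # same construction as _buffered_stats

# Appending works immediately for any key; missing keys start as [].
buffered['sample_sizes'].append(1500)
buffered['sample_sizes'].append(1421)
buffered['ooms_fwd'].append(0)

print(buffered['sample_sizes'])  # [1500, 1421]
buffered.clear()                 # emptied after each parameter update
```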
```diff
@@ -86,40 +88,68 @@ class Trainer(object):
         return extra_state

-    def train_step(self, sample):
+    def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
         sample = self._prepare_sample(sample, volatile=False)

-        # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample)
-
-        # aggregate stats and logging outputs
-        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
-        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
-        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
-
-        # backward pass, all-reduce gradients and take an optimization step
-        grad_norm, ooms_bwd = self._backward_and_opt(loss, grad_denom)
-
-        # update meters
-        self.meters['wps'].update(ntokens)
-        self.meters['ups'].update(1.)
-        self.meters['wpb'].update(ntokens)
-        self.meters['bsz'].update(nsentences)
-        self.meters['gnorm'].update(grad_norm)
-        self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
-        self.meters['oom'].update(ooms_fwd + ooms_bwd)
-
-        # update loss meters for training
-        if 'loss' in agg_logging_output:
-            self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
-        # criterions can optionally log the NLL loss too
-        if 'nll_loss' in agg_logging_output:
-            self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
-
-        return agg_logging_output
+        # forward and backward pass
+        loss, sample_size, logging_output, oom_fwd = self._forward(sample)
+        oom_bwd = self._backward(loss)
+
+        # buffer stats and logging outputs
+        self._buffered_stats['sample_sizes'].append(sample_size)
+        self._buffered_stats['logging_outputs'].append(logging_output)
+        self._buffered_stats['ooms_fwd'].append(oom_fwd)
+        self._buffered_stats['ooms_bwd'].append(oom_bwd)
+
+        # update parameters
+        if update_params:
+            # gather logging outputs from all GPUs
+            sample_sizes = self._buffered_stats['sample_sizes']
+            logging_outputs = self._buffered_stats['logging_outputs']
+            ooms_fwd = self._buffered_stats['ooms_fwd']
+            ooms_bwd = self._buffered_stats['ooms_bwd']
+            if self.args.distributed_world_size > 1:
+                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
+                    lambda l: list(chain.from_iterable(l)),
+                    zip(*distributed_utils.all_gather_list(
+                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
+                    ))
+                )
+            ooms_fwd = sum(ooms_fwd)
+            ooms_bwd = sum(ooms_bwd)
+
+            # aggregate stats and logging outputs
+            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
+            nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
+            agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+
+            # all-reduce gradients and take an optimization step
+            grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+            grad_norm = self._opt(grad_denom)
+
+            # update meters
+            self.meters['wps'].update(ntokens)
+            self.meters['ups'].update(1.)
+            self.meters['wpb'].update(ntokens)
+            self.meters['bsz'].update(nsentences)
+            self.meters['gnorm'].update(grad_norm)
+            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
+            self.meters['oom'].update(ooms_fwd + ooms_bwd)
+
+            # update loss meters for training
+            if 'loss' in agg_logging_output:
+                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
+            # criterions can optionally log the NLL loss too
+            if 'nll_loss' in agg_logging_output:
+                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+
+            self._buffered_stats.clear()
+            return agg_logging_output
+        else:
+            return None  # buffering updates

     def _forward(self, sample, eval=False):
         # prepare model and optimizer
```
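The pattern the new `train_step` implements is plain gradient accumulation: run forward/backward on every mini-batch so gradients sum into `.grad`, but only call the optimizer every N batches. A minimal self-contained PyTorch sketch of the same idea (the model, optimizer, and data here are placeholders, not fairseq's):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)                       # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
update_freq = 4                                # mirror of --update-freq

optimizer.zero_grad()
for i in range(16):                            # 16 mini-batches -> 4 updates
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()                            # grads accumulate in p.grad

    if (i + 1) % update_freq == 0:             # every N batches...
        # (fairseq also rescales grads by a grad_denom here; omitted)
        optimizer.step()                       # ...take one big-batch step
        optimizer.zero_grad()                  # reset for the next round
```

This is also why the hunk below removes `self.optimizer.zero_grad()` from `_forward`: gradients must survive across buffered calls, so zeroing now happens only after `optimizer.step()` (see the `_opt` changes further down).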
```diff
@@ -127,7 +157,6 @@ class Trainer(object):
             self.model.eval()
         else:
             self.model.train()
-            self.optimizer.zero_grad()

         loss = None
         sample_size = 0
```
```diff
@@ -152,19 +181,9 @@ class Trainer(object):
                 else:
                     raise e

-        # synchronize logging outputs for multi-GPU training
-        if self.args.distributed_world_size > 1:
-            sample_sizes, logging_outputs, ooms = zip(*list(
-                distributed_utils.all_gather_list((sample_size, logging_output, oom))))
-            ooms = sum(ooms)
-        else:
-            sample_sizes = [sample_size]
-            logging_outputs = [logging_output]
-            ooms = oom
-
-        return loss, sample_sizes, logging_outputs, ooms
+        return loss, sample_size, logging_output, oom

-    def _backward_and_opt(self, loss, grad_denom):
+    def _backward(self, loss):
         oom = 0
         if loss is not None:
             try:
```
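The per-batch synchronization removed here reappears in `train_step` above, where each worker now contributes a *list* of buffered stats rather than a single value. The `zip(*all_gather_list(...))` plus `chain.from_iterable` dance regroups fields across workers and then flattens the per-worker lists. A pure-Python simulation with two fake workers (no distributed setup; `all_gather_list` replaced by a plain list, values invented):

```python
from itertools import chain

# Pretend two workers each buffered 2 mini-batches before the update.
# Each gathered element is (sample_sizes, logging_outputs) for one worker.
gathered = [
    ([1500, 1421], [{'ntokens': 1500}, {'ntokens': 1421}]),  # worker 0
    ([1498, 1510], [{'ntokens': 1498}, {'ntokens': 1510}]),  # worker 1
]

# zip(*gathered) regroups by field; chain.from_iterable flattens workers.
sample_sizes, logging_outputs = map(
    lambda l: list(chain.from_iterable(l)), zip(*gathered))

print(sample_sizes)          # [1500, 1421, 1498, 1510]
print(len(logging_outputs))  # 4 -- one log per buffered batch, all workers
```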
```diff
@@ -179,7 +198,9 @@ class Trainer(object):
                     self.optimizer.zero_grad()
                 else:
                     raise e
+        return oom

+    def _opt(self, grad_denom):
         # all-reduce grads and rescale by grad_denom
         if self.args.distributed_world_size > 1:
             grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
```
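Because backward passes sum gradients over all buffered batches (and, after the all-reduce, over all GPUs), the new `_opt` rescales by `grad_denom`, which the criterion computes from the combined `sample_sizes`. Dividing the summed gradient by the total sample size makes the buffered update numerically equivalent to one step on a single big batch. A hedged sketch of that rescaling step (standalone, not fairseq's exact code):

```python
import torch
import torch.nn as nn

def rescale_grads(model: nn.Module, grad_denom: float) -> None:
    """Divide accumulated gradients by the total sample size."""
    for p in model.parameters():
        if p.grad is not None:
            p.grad.data.div_(grad_denom)

model = nn.Linear(10, 1)
loss = model(torch.randn(4, 10)).sum()
loss.backward()
# e.g. grad_denom = sum of sample_sizes across buffered batches and GPUs
rescale_grads(model, grad_denom=4.0)
```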
```diff
@@ -197,12 +218,13 @@ class Trainer(object):
         # take an optimization step
         self.optimizer.step()
+        self.optimizer.zero_grad()
         self._num_updates += 1

         # update learning rate
         self.lr_scheduler.step_update(self._num_updates)

-        return grad_norm, oom
+        return grad_norm

     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
```
```diff
@@ -210,8 +232,17 @@ class Trainer(object):
         sample = self._prepare_sample(sample, volatile=True)

         # forward pass
-        loss, sample_sizes, logging_outputs, ooms_fwd = self._forward(sample, eval=True)
-        assert not ooms_fwd, 'Ran out of memory during validation'
+        _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
+        assert not oom_fwd, 'Ran out of memory during validation'
+
+        # gather logging outputs from all GPUs
+        if self.args.distributed_world_size > 1:
+            sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
+                (sample_size, logging_output)
+            ))
+        else:
+            sample_sizes = [sample_size]
+            logging_outputs = [logging_output]

         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
```
singleprocess_train.py
```diff
@@ -132,6 +132,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         num_shards=args.distributed_world_size,
     )
     progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
+    epoch_size = len(itr)
     itr = itertools.islice(progress, batch_offset, None)

     # reset training meters
```
```diff
@@ -143,7 +144,12 @@ def train(args, trainer, dataset, epoch, batch_offset):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
-        log_output = trainer.train_step(sample)
+        if i < epoch_size - 1 and (i + 1) % args.update_freq > 0:
+            # buffer updates according to --update-freq
+            trainer.train_step(sample, update_params=False)
+            continue
+        else:
+            log_output = trainer.train_step(sample, update_params=True)

         # log mid-epoch stats
         stats = get_training_stats(trainer)
```
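The loop buffers every batch whose 1-based index is not a multiple of `update_freq`, and always forces an update on the last batch of the epoch; that is why `epoch_size` is captured with `len(itr)` before the iterator is sliced, since the sliced iterator has no length. A quick simulation of which iterations trigger an update, assuming a toy epoch of 10 batches and `--update-freq 4`:

```python
epoch_size = 10   # toy epoch length; the real value comes from len(itr)
update_freq = 4   # --update-freq

updates = [i for i in range(epoch_size)
           if not (i < epoch_size - 1 and (i + 1) % update_freq > 0)]
print(updates)    # [3, 7, 9] -- every 4th batch, plus the epoch's last batch
```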
```diff
@@ -157,9 +163,8 @@ def train(args, trainer, dataset, epoch, batch_offset):
             stats[k] = extra_meters[k].avg
         progress.log(stats)

-        # save mid-epoch checkpoints
-        # ignore the first mini-batch in words-per-second calculation
         if i == batch_offset:
+            # ignore the first mini-batch in words-per-second calculation
            trainer.get_meter('wps').reset()

+        # save mid-epoch checkpoints
```