OpenDAS / Fairseq · Commit 7ee1d284

Add FP16 support

Authored Apr 10, 2018 by Myle Ott
Parent: 73a87327
Showing 9 changed files with 309 additions and 131 deletions.
fairseq/data.py                    +20    -0
fairseq/distributed_utils.py        +0   -52
fairseq/fp16_trainer.py           +141    -0
fairseq/models/fairseq_decoder.py   +1    -1
fairseq/options.py                  +3    -0
fairseq/trainer.py                +101   -62
fairseq/utils.py                   +27   -11
scripts/average_checkpoints.py      +1    -1
singleprocess_train.py             +15    -4
fairseq/data.py

```diff
@@ -12,7 +12,9 @@ import math
 import numbers
 import numpy as np
 import os
 import torch
+from torch.autograd import Variable
 import torch.utils.data

 from fairseq.dictionary import Dictionary
@@ -435,3 +437,21 @@ def numpy_seed(seed):
         yield
     finally:
         np.random.set_state(state)
+
+
+def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
+    bsz = int(ntokens / max(src_len, tgt_len))
+    bsz = (bsz // 8) * 8
+    assert src_dict.pad() == dst_dict.pad()
+    pad_idx = src_dict.pad()
+    src_vocab, dst_vocab = len(src_dict), len(dst_dict)
+    dummy_batch = {}
+    dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
+    dummy_batch['ntokens'] = tgt_len * bsz
+    dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
+    input = {}
+    input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
+    input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
+    input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
+    dummy_batch['net_input'] = input
+    return dummy_batch
```
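The sizing in `get_dummy_batch` fills the token budget and then snaps the batch size down to a multiple of 8, presumably to keep shapes friendly to tensor-core kernels. A standalone sketch of just that arithmetic (plain Python, no GPU or fairseq dictionaries needed):

```python
# Sizing logic from get_dummy_batch, isolated so it runs anywhere.
def dummy_batch_size(ntokens, src_len=128, tgt_len=128):
    bsz = int(ntokens / max(src_len, tgt_len))  # fill the token budget
    return (bsz // 8) * 8                       # snap down to a multiple of 8

print(dummy_batch_size(6000))  # 6000 / 128 = 46.875 -> 46 -> 40
print(dummy_batch_size(4096))  # 32
```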
fairseq/distributed_utils.py

```diff
@@ -53,58 +53,6 @@ def suppress_output():
     __builtin__.print = print


-def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
-    """All-reduce and rescale tensors in chunks of the specified size.
-
-    Args:
-        tensors: list of Tensors to all-reduce
-        rescale_denom: denominator for rescaling summed Tensors
-        buffer_size: all-reduce chunk size in bytes
-    """
-    # buffer size is in bytes, determine equiv. # of elements based on data type
-    buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
-    buffer = []
-
-    def all_reduce_buffer():
-        # copy tensors into buffer_t
-        offset = 0
-        for t in buffer:
-            numel = t.numel()
-            buffer_t[offset:offset+numel].copy_(t.view(-1))
-            offset += numel
-
-        # all-reduce and rescale
-        torch.distributed.all_reduce(buffer_t[:offset])
-        buffer_t.div_(rescale_denom)
-
-        # copy all-reduced buffer back into tensors
-        offset = 0
-        for t in buffer:
-            numel = t.numel()
-            t.view(-1).copy_(buffer_t[offset:offset+numel])
-            offset += numel
-
-    filled = 0
-    for t in tensors:
-        sz = t.numel() * t.element_size()
-        if sz > buffer_size:
-            # tensor is bigger than buffer, all-reduce and rescale directly
-            torch.distributed.all_reduce(t)
-            t.div_(rescale_denom)
-        elif filled + sz > buffer_size:
-            # buffer is full, all-reduce and replace buffer with grad
-            all_reduce_buffer()
-            buffer = [t]
-            filled = sz
-        else:
-            # add tensor to buffer
-            buffer.append(t)
-            filled += sz
-    if len(buffer) > 0:
-        all_reduce_buffer()
-
-
 def all_gather_list(data, max_size=4096):
     """Gathers arbitrary data from all nodes into a list."""
     world_size = torch.distributed.get_world_size()
```
fairseq/fp16_trainer.py (new file, mode 100644)

```python
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Train a network on multiple GPUs.
"""

import math

import torch

from fairseq import optim
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer


class DynamicLossScaler:

    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor
        self._iter += 1

    @staticmethod
    def has_overflow(grad_norm):
        # detect inf and nan
        if grad_norm == float('inf') or grad_norm != grad_norm:
            return True
        return False


class FP16Trainer(Trainer):
    """Modified trainer for FP16.

    We maintain two copies of the model's parameters, both in FP16 and FP32.
    We do forward/backward with FP16 and compute the loss + optimize with FP32.
    """

    def __init__(self, args, model, criterion):
        super().__init__(args, model, criterion)

        # convert model to FP16 (but keep criterion FP32)
        self.model.half()

        # dynamically scale loss to reduce overflow
        self.scaler = DynamicLossScaler(init_scale=2.**7)
        self.meters['loss_scale'] = AverageMeter()

    def _build_optimizer(self):
        # create FP32 copy of parameters and grads
        params = [p for p in self.model.parameters() if p.requires_grad]
        total_param_size = sum(p.data.numel() for p in params)
        self.fp32_params = params[0].new(0).float().new(total_param_size)
        offset = 0
        for p in params:
            numel = p.data.numel()
            self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
            offset += numel
        self.fp32_params = torch.nn.Parameter(self.fp32_params)
        self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

        # create optimizer using the copied FP32 params
        self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state['loss_scale'] = self.scaler.loss_scale
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(self, filename):
        """Load all training state from a checkpoint file."""
        extra_state = super().load_checkpoint(filename)
        if extra_state is not None and 'loss_scale' in extra_state:
            self.scaler.loss_scale = extra_state['loss_scale']
        return extra_state

    def zero_grad(self):
        # zero both the FP16 and FP32 grads
        self.model.zero_grad()      # FP16
        self.optimizer.zero_grad()  # FP32

    def _backward(self, loss):
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.scaler.loss_scale)
        if loss is not None:
            # dynamically rescale loss to stay in FP16 range
            loss = loss * self.scaler.loss_scale
        return super()._backward(loss)

    def _all_reduce_and_rescale(self, grad_denom):
        # undo effect of dynamic loss scaling on gradients
        grad_denom *= self.scaler.loss_scale

        # all-reduce and rescale gradients
        grad_norm = super()._all_reduce_and_rescale(grad_denom)

        # detect overflow and adjust loss scale
        overflow = DynamicLossScaler.has_overflow(grad_norm)
        self.scaler.update_scale(overflow)
        if overflow:
            raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))

        return grad_norm

    def _get_flat_grads(self, out=None):
        if out is None:
            out = self.fp32_params.grad
        return super()._get_flat_grads(out)

    def _set_flat_grads(self, new_grads):
        # no-op
        assert new_grads.data_ptr() == self.fp32_params.grad.data.data_ptr()

    def _opt(self):
        # take an optimization step using the FP32 params and grads
        super()._opt()

        # copy FP32 params back into FP16 model
        offset = 0
        for p in self.model.parameters():
            if not p.requires_grad:
                continue
            numel = p.data.numel()
            p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
            offset += numel
```
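`DynamicLossScaler` halves the scale whenever the gradients overflow and doubles it after every `scale_window` overflow-free updates. A self-contained walk-through of that schedule, with a small window so the behavior is visible (the class body is copied from the file above):

```python
class DynamicLossScaler:
    def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._iter = 0
        self._last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            self.loss_scale /= self.scale_factor      # back off after inf/nan grads
            self._last_overflow_iter = self._iter
        elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
            self.loss_scale *= self.scale_factor      # grow after a clean window
        self._iter += 1

scaler = DynamicLossScaler(init_scale=2.**7, scale_window=4)
for step, overflow in enumerate([False, True, False, False, False, False]):
    scaler.update_scale(overflow)
    print(step, scaler.loss_scale)
# step 1 overflows (128 -> 64); at step 5, (5 - 1) % 4 == 0, so it doubles back to 128.
```

Note that `FP16Trainer._all_reduce_and_rescale` raises `OverflowError` when the gradient norm is inf or NaN; `train_step` in fairseq/trainer.py catches it, zeroes the gradients, and skips the update, so one bad batch only costs a single step.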
fairseq/models/fairseq_decoder.py

```diff
@@ -21,7 +21,7 @@ class FairseqDecoder(nn.Module):
     def get_normalized_probs(self, net_output, log_probs):
         """Get normalized probabilities (or log probs) from a net's output."""
-        logits = net_output[0]
+        logits = net_output[0].float()
         if log_probs:
             return F.log_softmax(logits, dim=-1)
         else:
```
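The `.float()` cast makes the (log-)softmax run in FP32 even when the decoder emits FP16 logits; exponentiation and the normalizing sum are exactly where half precision loses accuracy or overflows. A small illustration of the same pattern (any recent PyTorch; shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5).half()                   # FP16 decoder output
log_probs = F.log_softmax(logits.float(), dim=-1)   # normalize in FP32
print(log_probs.exp().sum(dim=-1))                  # each row sums to ~1.0
```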
fairseq/options.py

```diff
@@ -155,6 +155,9 @@ def add_optimization_args(parser):
                        ' (default is to normalize by number of tokens)')
     group.add_argument('--update-freq', default='1', metavar='N',
                        help='update parameters every N_i batches, when in epoch i')
+    has_tensor_cores = torch.cuda.device_count() > 0 and torch.cuda.get_device_capability(0)[0] >= 7
+    group.add_argument('--fp16', action='store_true', default=has_tensor_cores,
+                       help='use FP16 during training')

     # Optimizer definitions can be found under fairseq/optim/
     group.add_argument('--optimizer', default='nag', metavar='OPT',
```
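`--fp16` therefore defaults to on whenever a GPU with compute capability 7.0 or higher (i.e. with tensor cores) is visible, and off otherwise. A minimal standalone sketch of that default-from-hardware pattern (plain argparse, not the full fairseq options module):

```python
import argparse
import torch

# FP16 defaults to on only where tensor cores exist (compute capability >= 7).
has_tensor_cores = (torch.cuda.device_count() > 0
                    and torch.cuda.get_device_capability(0)[0] >= 7)

parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true', default=has_tensor_cores,
                    help='use FP16 during training')
print(parser.parse_args([]).fp16)          # hardware-dependent default
print(parser.parse_args(['--fp16']).fp16)  # always True when passed explicitly
```

One quirk worth noting: because the flag is `store_true`, this diff provides no command-line way to force FP16 off on a machine where the default comes out `True`.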
fairseq/trainer.py

```diff
@@ -6,7 +6,7 @@
 # can be found in the PATENTS file in the same directory.

 """
-Train a network on multiple GPUs.
+Train a network across multiple GPUs.
 """

 from collections import defaultdict, OrderedDict
@@ -20,11 +20,11 @@ from fairseq.optim import lr_scheduler
 class Trainer(object):
-    """Main class for multi-GPU training.
+    """Main class for data parallel training.

-    Each GPU has a full copy of the model and is assigned to its own Python
-    process. Gradients are accumulated with torch.distributed.all_reduce and all
-    model replicas are updated synchronously after each batch.
+    This class supports data parallel training, where multiple workers each
+    have a full model replica and gradients are accumulated synchronously via
+    torch.distributed.all_reduce.
     """

     def __init__(self, args, model, criterion):
@@ -39,8 +39,7 @@ class Trainer(object):
         self.criterion = criterion.cuda()

         # initialize optimizer and LR scheduler
-        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
-        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+        self._build_optimizer()

         # initialize meters
         self.meters = OrderedDict()
@@ -55,12 +54,17 @@ class Trainer(object):
         self.meters['gnorm'] = AverageMeter()   # gradient norm
         self.meters['clip'] = AverageMeter()    # % of updates clipped
         self.meters['oom'] = AverageMeter()     # out of memory
+        self.meters['wall'] = TimeMeter()       # wall time in seconds

         self._buffered_stats = defaultdict(lambda: [])
-        self._max_bsz_seen = 0
+        self._flat_grads = None
         self._num_updates = 0
         self._optim_history = None

+    def _build_optimizer(self):
+        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
+        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+
     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
         if distributed_utils.is_master(self.args):  # only save one checkpoint
@@ -69,13 +73,12 @@ class Trainer(object):
     def load_checkpoint(self, filename):
         """Load all training state from a checkpoint file."""
-        extra_state, self._optim_history, last_optim_state = utils.load_model_state(
-            filename, self.model, cuda_device=torch.cuda.current_device())
+        extra_state, self._optim_history, last_optim_state = \
+            utils.load_model_state(filename, self.model)

         if last_optim_state is not None:
             # rebuild optimizer after loading model, since params may have changed
-            self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
-            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+            self._build_optimizer()

             # only reload optimizer and lr_scheduler if they match
             last_optim = self._optim_history[-1]
@@ -105,7 +108,7 @@ class Trainer(object):
         # update parameters
         if update_params:
-            # gather logging outputs from all GPUs
+            # gather logging outputs from all replicas
             sample_sizes = self._buffered_stats['sample_sizes']
             logging_outputs = self._buffered_stats['logging_outputs']
             ooms_fwd = self._buffered_stats['ooms_fwd']
@@ -124,28 +127,34 @@ class Trainer(object):
             ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
             nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
             agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)

             # all-reduce gradients and take an optimization step
             grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
-            grad_norm = self._opt(grad_denom)
-
-            # update meters
-            self.meters['wps'].update(ntokens)
-            self.meters['ups'].update(1.)
-            self.meters['wpb'].update(ntokens)
-            self.meters['bsz'].update(nsentences)
-            self.meters['gnorm'].update(grad_norm)
-            self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
-            self.meters['oom'].update(ooms_fwd + ooms_bwd)
-
-            # update loss meters for training
-            if 'loss' in agg_logging_output:
-                self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
-            # criterions can optionally log the NLL loss too
-            if 'nll_loss' in agg_logging_output:
-                self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
-
-            self._buffered_stats.clear()
+            try:
+                # all-reduce and rescale gradients, then take an optimization step
+                grad_norm = self._all_reduce_and_rescale(grad_denom)
+                self._opt()
+
+                # update meters
+                self.meters['wps'].update(ntokens)
+                self.meters['ups'].update(1.)
+                self.meters['wpb'].update(ntokens)
+                self.meters['bsz'].update(nsentences)
+                if grad_norm is not None:
+                    self.meters['gnorm'].update(grad_norm)
+                    self.meters['clip'].update(1. if grad_norm > self.args.clip_norm else 0.)
+                self.meters['oom'].update(ooms_fwd + ooms_bwd)
+
+                # update loss meters for training
+                if 'loss' in agg_logging_output:
+                    self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
+                # criterions can optionally log the NLL loss too
+                if 'nll_loss' in agg_logging_output:
+                    self.meters['train_nll_loss'].update(agg_logging_output['nll_loss'], ntokens)
+            except OverflowError as e:
+                self.zero_grad()
+                print('| WARNING: overflow detected, ' + str(e))
+
+            self.clear_buffered_stats()

             return agg_logging_output
         else:
@@ -157,7 +166,6 @@ class Trainer(object):
             self.model.eval()
         else:
             self.model.train()
-
         loss = None
         sample_size = 0
         logging_output = {
@@ -176,11 +184,8 @@ class Trainer(object):
                     print('| WARNING: ran out of memory, skipping batch')
                     oom = 1
                     loss = None
-                    if hasattr(torch.cuda, 'empty_cache'):
-                        torch.cuda.empty_cache()
                 else:
                     raise e
         return loss, sample_size, logging_output, oom

     def _backward(self, loss):
@@ -193,39 +198,66 @@ class Trainer(object):
                 if 'out of memory' in str(e):
                     print('| WARNING: ran out of memory, skipping batch')
                     oom = 1
-                    if hasattr(torch.cuda, 'empty_cache'):
-                        torch.cuda.empty_cache()
-                    self.optimizer.zero_grad()
+                    self.zero_grad()
                 else:
                     raise e
         return oom

-    def _opt(self, grad_denom):
-        # all-reduce grads and rescale by grad_denom
+    def _all_reduce_and_rescale(self, grad_denom):
+        # flatten grads into a single buffer and all-reduce
+        flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
         if self.args.distributed_world_size > 1:
-            grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
-            distributed_utils.all_reduce_and_rescale_tensors(grads, grad_denom)
-        else:
-            for p in self.model.parameters():
-                if p.requires_grad:
-                    p.grad.data.div_(grad_denom)
+            torch.distributed.all_reduce(flat_grads)

-        # clip grads
-        if self.args.clip_norm > 0:
-            grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm))
-        else:
-            grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters()))
+        # rescale and clip gradients
+        flat_grads.div_(grad_denom)
+        grad_norm = utils.clip_grad_norm_(flat_grads, self.args.clip_norm)
+
+        # copy grads back into model parameters
+        self._set_flat_grads(flat_grads)
+
+        return grad_norm
+
+    def _get_grads(self):
+        grads = []
+        for name, p in self.model.named_parameters():
+            if not p.requires_grad:
+                continue
+            if p.grad is None:
+                raise RuntimeError('Model parameter did not receive gradient: ' + name + '. '
+                                   'Use the param in the forward pass or set requires_grad=False')
+            grads.append(p.grad.data)
+        return grads
+
+    def _get_flat_grads(self, out=None):
+        grads = self._get_grads()
+        if out is None:
+            grads_size = sum(g.numel() for g in grads)
+            out = grads[0].new(grads_size).zero_()
+        offset = 0
+        for g in grads:
+            numel = g.numel()
+            out[offset:offset+numel].copy_(g.view(-1))
+            offset += numel
+        return out[:offset]
+
+    def _set_flat_grads(self, new_grads):
+        grads = self._get_grads()
+        offset = 0
+        for g in grads:
+            numel = g.numel()
+            g.copy_(new_grads[offset:offset+numel].view_as(g))
+            offset += numel

+    def _opt(self):
         # take an optimization step
         self.optimizer.step()
-        self.optimizer.zero_grad()
+        self.zero_grad()
         self._num_updates += 1

         # update learning rate
         self.lr_scheduler.step_update(self._num_updates)

-        return grad_norm
-
     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
@@ -258,6 +290,18 @@ class Trainer(object):
         return agg_logging_output

+    def dummy_train_step(self, dummy_batch):
+        """Dummy training step for warming caching allocator."""
+        self.train_step(dummy_batch, update_params=False)
+        self.zero_grad()
+        self.clear_buffered_stats()
+
+    def zero_grad(self):
+        self.optimizer.zero_grad()
+
+    def clear_buffered_stats(self):
+        self._buffered_stats.clear()
+
     def lr_step(self, epoch, val_loss=None):
         """Adjust the learning rate based on the validation loss."""
         return self.lr_scheduler.step(epoch, val_loss)
@@ -283,9 +327,4 @@ class Trainer(object):
     def _prepare_sample(self, sample, volatile):
         if sample is None or len(sample) == 0:
             return None
-        if hasattr(torch.cuda, 'empty_cache'):
-            # clear the caching allocator if this is the largest sample we've seen
-            if sample['target'].size(0) > self._max_bsz_seen:
-                self._max_bsz_seen = sample['target'].size(0)
-                torch.cuda.empty_cache()
         return utils.make_variable(sample, volatile=volatile, cuda=True)
```
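The restructuring replaces per-tensor all-reduce with a single flat buffer: `_get_flat_grads` concatenates every gradient into one contiguous tensor, the buffer is all-reduced, rescaled, and clipped as a whole, and `_set_flat_grads` copies it back into the parameters. This is also the hook `FP16Trainer` overrides to point the flat buffer at its FP32 gradient copy. A standalone sketch of the round trip, with plain tensors standing in for parameter gradients:

```python
import torch

# Stand-ins for per-parameter gradients.
grads = [torch.randn(3, 4), torch.randn(10)]

# _get_flat_grads: concatenate into one contiguous buffer.
flat = torch.cat([g.view(-1) for g in grads])

# Rescale by grad_denom, then clip the whole buffer at once
# (this is what utils.clip_grad_norm_ does on the flat tensor).
flat.div_(8)
max_norm = 0.1
norm = flat.norm().item()
if norm > max_norm > 0:
    flat.mul_(max_norm / (norm + 1e-6))

# _set_flat_grads: copy the buffer back into the original tensors.
offset = 0
for g in grads:
    numel = g.numel()
    g.copy_(flat[offset:offset + numel].view_as(g))
    offset += numel

print(torch.cat([g.view(-1) for g in grads]).norm())  # ~= max_norm
```

One all-reduce over one buffer also removes the need for the chunked `all_reduce_and_rescale_tensors` helper deleted from fairseq/distributed_utils.py above.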
fairseq/utils.py

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 import contextlib
 import logging
 import os
@@ -25,6 +25,20 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())


+def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
+    if isinstance(state_dict, dict):
+        cpu_dict = OrderedDict()
+        for k, v in state_dict.items():
+            cpu_dict[k] = convert_state_dict_type(v)
+        return cpu_dict
+    elif isinstance(state_dict, list):
+        return [convert_state_dict_type(v) for v in state_dict]
+    elif torch.is_tensor(state_dict):
+        return state_dict.type(ttype)
+    else:
+        return state_dict
+
+
 def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
                num_updates, optim_history=None, extra_state=None):
     if optim_history is None:
@@ -33,7 +47,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
         extra_state = {}
     state_dict = {
         'args': args,
-        'model': model.state_dict(),
+        'model': convert_state_dict_type(model.state_dict()),
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
@@ -42,22 +56,16 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
                 'num_updates': num_updates,
             }
         ],
-        'last_optimizer_state': optimizer.state_dict(),
+        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
         'extra_state': extra_state,
     }
     torch_persistent_save(state_dict, filename)


-def load_model_state(filename, model, cuda_device=None):
+def load_model_state(filename, model):
     if not os.path.exists(filename):
         return None, [], None
-    if cuda_device is None:
-        state = torch.load(filename)
-    else:
-        state = torch.load(
-            filename,
-            map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
-        )
+    state = torch.load(filename)
     state = _upgrade_state_dict(state)
     state['model'] = model.upgrade_state_dict(state['model'])
@@ -377,6 +385,14 @@ def item(tensor):
     return tensor


+def clip_grad_norm_(tensor, max_norm):
+    grad_norm = item(torch.norm(tensor))
+    if grad_norm > max_norm > 0:
+        clip_coef = max_norm / (grad_norm + 1e-6)
+        tensor.mul_(clip_coef)
+    return grad_norm
+
+
 def fill_with_neg_inf(t):
     """FP16-compatible function that fills a tensor with -inf."""
     return t.float().fill_(float('-inf')).type_as(t)
```
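`convert_state_dict_type` recurses through dicts and lists and casts any tensor it finds to `ttype` (FP32 by default), so checkpoints saved from a half-precision model are written in full precision. A small usage sketch, with the function copied verbatim from the diff above:

```python
import torch
from collections import OrderedDict

def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            cpu_dict[k] = convert_state_dict_type(v)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict

sd = {'weight': torch.zeros(2, 2).half(), 'step': 7}
converted = convert_state_dict_type(sd)
print(converted['weight'].dtype)  # torch.float32
print(converted['step'])          # non-tensor values pass through unchanged
```

The new `clip_grad_norm_` is the flat-buffer counterpart of `torch.nn.utils.clip_grad_norm`: since the trainer now holds all gradients in a single tensor, the norm is one `torch.norm` call rather than a loop over parameters.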
scripts/average_checkpoints.py

```diff
@@ -44,7 +44,7 @@ def average_checkpoints(inputs):
         for k in params_keys:
             if k not in params_dict:
                 params_dict[k] = []
-            params_dict[k].append(model_params[k])
+            params_dict[k].append(model_params[k].float())
     averaged_params = collections.OrderedDict()
     # v should be a list of torch Tensor.
```
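The one-line change casts each parameter to FP32 as it is collected, so checkpoints written by an FP16 run are summed and averaged in full precision rather than half. A minimal illustration (values are arbitrary):

```python
import torch

# Parameters from several FP16 checkpoints.
params = [torch.full((3,), 0.1).half() for _ in range(5)]

# Cast to FP32 before accumulating, as the script now does.
avg = torch.stack([p.float() for p in params]).mean(dim=0)
print(avg.dtype, avg)  # torch.float32, values ~0.1
```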
singleprocess_train.py

```diff
@@ -13,8 +13,9 @@ import math
 import torch

 from fairseq import criterions, data, models, options, progress_bar
-from fairseq.meters import AverageMeter, StopwatchMeter
+from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
+from fairseq.meters import AverageMeter, StopwatchMeter


 def main(args):
@@ -48,7 +49,10 @@ def main(args):
     print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))

     # Build trainer
-    trainer = Trainer(args, model, criterion)
+    if args.fp16:
+        trainer = FP16Trainer(args, model, criterion)
+    else:
+        trainer = Trainer(args, model, criterion)
     print('| training on {} GPUs'.format(args.distributed_world_size))
     print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
         args.max_tokens,
@@ -84,6 +88,10 @@ def main(args):
         _ = next(train_dataloader)
         epoch += 1

+    # Send a dummy batch to warm the caching allocator
+    dummy_batch = data.get_dummy_batch(args.max_tokens, dataset.src_dict, dataset.dst_dict)
+    trainer.dummy_train_step(dummy_batch)
+
     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
     max_update = args.max_update or math.inf
@@ -153,7 +161,7 @@ def train(args, trainer, itr, epoch):
             # log mid-epoch stats
             stats = get_training_stats(trainer)
             for k, v in log_output.items():
-                if k in ['loss', 'nll_loss']:
+                if k in ['loss', 'nll_loss', 'sample_size']:
                     continue  # these are already logged above
                 if 'loss' in k:
                     extra_meters[k].update(v, log_output['sample_size'])
@@ -194,6 +202,9 @@ def get_training_stats(trainer):
     stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
     stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
     stats['oom'] = trainer.get_meter('oom').avg
+    if trainer.get_meter('loss_scale') is not None:
+        stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
+    stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
     return stats
@@ -234,7 +245,7 @@ def validate(args, trainer, dataset, subset, epoch):
             # log mid-validation stats
             stats = get_valid_stats(trainer)
             for k, v in log_output.items():
-                if k in ['loss', 'nll_loss']:
+                if k in ['loss', 'nll_loss', 'sample_size']:
                     continue
                 extra_meters[k].update(v)
                 stats[k] = extra_meters[k].avg
```
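The dummy batch sized by `args.max_tokens` runs one full forward/backward before training starts, so PyTorch's caching allocator reserves its peak memory up front; this replaces the old `_max_bsz_seen` / `torch.cuda.empty_cache()` heuristic deleted from `_prepare_sample`. A hypothetical standalone illustration of the same idea (CUDA-only, toy model, not fairseq code):

```python
import torch

# Warm the caching allocator with one maximal-size step; later, smaller
# batches reuse the cached blocks instead of triggering fresh allocations.
if torch.cuda.is_available():
    model = torch.nn.Linear(1024, 1024).cuda().half()
    dummy = torch.randn(64, 1024).half().cuda()   # largest batch we expect
    model(dummy).float().sum().backward()         # allocate peak memory once
    model.zero_grad()                             # discard the dummy gradients
```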