Fairseq commit e432459b
Authored Sep 27, 2017 by Myle Ott
Parent commit: 48631f7a

Add optimizer history to checkpoints (and rearrange criterions slightly)
Showing 6 changed files with 98 additions and 49 deletions
    fairseq/criterions/cross_entropy.py                   +3   -3
    fairseq/criterions/fairseq_criterion.py               +8   -4
    fairseq/criterions/label_smoothed_cross_entropy.py    +3   -3
    fairseq/multiprocessing_trainer.py                     +35  -23
    fairseq/utils.py                                       +42  -9
    train.py                                               +7   -7
fairseq/criterions/cross_entropy.py

@@ -18,14 +18,14 @@ class CrossEntropyCriterion(FairseqCriterion):
         super().__init__()
         self.padding_idx = padding_idx
 
-    def prepare(self, samples):
-        self.denom = sum(s['ntokens'] if s else 0 for s in samples)
+    def grad_denom(self, samples):
+        return sum(s['ntokens'] if s else 0 for s in samples)
 
     def forward(self, net_output, sample):
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return loss / self.denom
+        return loss
 
     def aggregate(self, losses):
         return sum(losses) / math.log(2)
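The effect of this change is that the criterion no longer stashes a normalizer on `self.denom`: it returns the summed token-level loss, and the trainer divides by `grad_denom` (the total `ntokens` across the scattered samples, computed once per step). A rough plain-PyTorch sketch of that equivalence, using toy tensors rather than fairseq code (`reduction='sum'` stands in for the older `size_average=False`):

import torch
import torch.nn.functional as F

padding_idx = 0
samples = [
    {'ntokens': 3, 'logits': torch.randn(3, 5), 'target': torch.tensor([1, 2, 3])},
    {'ntokens': 2, 'logits': torch.randn(2, 5), 'target': torch.tensor([4, 1])},
]

# grad_denom(samples): total number of target tokens, treating missing samples as 0
grad_denom = sum(s['ntokens'] if s else 0 for s in samples)

# forward(): unnormalized per-sample sums, ignoring padding tokens
losses = [
    F.cross_entropy(s['logits'], s['target'], reduction='sum', ignore_index=padding_idx)
    for s in samples
]

# the trainer divides each replica's loss by grad_denom before backward(),
# which reproduces the old in-criterion division by self.denom
print(sum(l.item() for l in losses) / grad_denom)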
fairseq/criterions/fairseq_criterion.py

@@ -11,13 +11,17 @@ from torch.nn.modules.loss import _Loss
 class FairseqCriterion(_Loss):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self):
+        super().__init__()
 
-    def prepare(self, samples):
-        """Prepare criterion for DataParallel training."""
+    def grad_denom(self, samples):
+        """Gradient normalization term for DataParallel training."""
         raise NotImplementedError
 
+    def prepare(self, model, sample):
+        """Apply criterion-specific modifications to the sample."""
+        return sample
+
     def forward(self, net_output, sample):
         """Compute the loss for the given sample and network output."""
         raise NotImplementedError
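For orientation, here is a hypothetical criterion written against the new base-class contract (`grad_denom`, the new `prepare(model, sample)` hook, `forward`, and `aggregate`). The class name is made up, it is not part of the commit, and `_Loss` stands in directly for `FairseqCriterion` so the snippet runs on its own:

import math

import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class SumCrossEntropyCriterion(_Loss):  # illustrative stand-in for a FairseqCriterion subclass
    def __init__(self, padding_idx):
        super().__init__()
        self.padding_idx = padding_idx

    def grad_denom(self, samples):
        # gradient normalization term for DataParallel training
        return sum(s['ntokens'] if s else 0 for s in samples)

    def prepare(self, model, sample):
        # criterion-specific modifications to the sample (default: pass through)
        return sample

    def forward(self, net_output, sample):
        lprobs = net_output.view(-1, net_output.size(-1))
        target = sample['target'].view(-1)
        # unnormalized sum; the trainer divides by grad_denom afterwards
        return F.cross_entropy(lprobs, target, reduction='sum', ignore_index=self.padding_idx)

    def aggregate(self, losses):
        # convert from nats to bits when reporting
        return sum(losses) / math.log(2)


criterion = SumCrossEntropyCriterion(padding_idx=0)
print(criterion.grad_denom([{'ntokens': 3}, None]))  # -> 3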
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -49,14 +49,14 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.padding_idx = padding_idx
         self.weights = weights
 
-    def prepare(self, samples):
-        self.denom = sum(s['ntokens'] if s else 0 for s in samples)
+    def grad_denom(self, samples):
+        return sum(s['ntokens'] if s else 0 for s in samples)
 
     def forward(self, net_output, sample):
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return loss / self.denom
+        return loss
 
     def aggregate(self, losses):
         return sum(losses) / math.log(2)
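`LabelSmoothedCrossEntropy.apply` is a custom autograd Function defined elsewhere in fairseq. For readers who just want the gist of what this criterion computes, here is a rough plain-PyTorch illustration of label-smoothed NLL under the same convention (unnormalized sum over tokens, padding ignored). It is not the fairseq implementation and it ignores the optional per-class `weights`:

import torch
import torch.nn.functional as F


def label_smoothed_nll(lprobs, target, eps, padding_idx):
    # negative log-likelihood of the gold targets
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    # smoothing term: negative log-prob summed over the whole vocabulary
    smooth = -lprobs.sum(dim=-1)
    pad_mask = target.eq(padding_idx)
    nll = nll.masked_fill(pad_mask, 0.0)
    smooth = smooth.masked_fill(pad_mask, 0.0)
    eps_i = eps / lprobs.size(-1)
    # unnormalized sum, as in the criterion above; the trainer divides by grad_denom
    return (1.0 - eps) * nll.sum() + eps_i * smooth.sum()


logits = torch.randn(6, 10)
target = torch.tensor([3, 1, 0, 7, 2, 5])  # 0 is the padding index here
loss = label_smoothed_nll(F.log_softmax(logits, dim=-1), target, eps=0.1, padding_idx=0)
print(loss.item())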
fairseq/multiprocessing_trainer.py

@@ -32,7 +32,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     (prefixed with `_async_`), which run on each process in parallel.
     """
 
-    def __init__(self, args, model, device_ids=None,
+    def __init__(self, args, model, criterion, device_ids=None,
                  multiprocessing_method='spawn'):
         if device_ids is None:
             device_ids = tuple(range(torch.cuda.device_count()))
@@ -42,16 +42,17 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             raise NotImplementedError('Training on CPU is not supported')
 
         model = model.share_memory()
         nccl_uid = nccl.get_unique_id()
+        self.criterion = criterion
 
         Future.gen_list([
             self.call_async(rank, '_async_init', args=args, model=model,
-                            nccl_uid=nccl_uid)
+                            criterion=criterion, nccl_uid=nccl_uid)
             for rank in range(self.num_replicas)
         ])
 
         self._grads_initialized = False
 
-    def _async_init(self, rank, device_id, args, model, nccl_uid):
+    def _async_init(self, rank, device_id, args, model, criterion, nccl_uid):
         """Initialize child processes."""
         self.args = args
@@ -64,8 +65,9 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # initialize NCCL
         nccl.initialize(self.num_replicas, nccl_uid, device_id)
 
-        # copy model to current device
+        # copy model and criterion to current device
         self.model = model.cuda()
+        self.criterion = criterion.cuda()
 
         # initialize optimizer
         self.optimizer = NAG(self.model.parameters(), lr=self.args.lr,
@@ -104,8 +106,8 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             batch_offset=batch_offset, val_loss=val_loss).gen()
 
     def _async_save_checkpoint(self, rank, device_id, args, epoch, batch_offset, val_loss):
-        utils.save_checkpoint(args, epoch, batch_offset, self.model,
-                              self.optimizer, self.lr_scheduler, val_loss)
+        utils.save_checkpoint(args, epoch, batch_offset, self.model, self.criterion,
+                              self.optimizer, self.lr_scheduler, val_loss, self._optim_history)
 
     def load_checkpoint(self, filename):
         """Load a checkpoint into the model replicas in each process."""
@@ -117,13 +119,13 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         return epoch, batch_offset
 
     def _async_load_checkpoint(self, rank, device_id, filename):
-        return utils.load_checkpoint(filename, self.model, self.optimizer,
-                                     self.lr_scheduler, cuda_device=device_id)
+        epoch, batch_offset, self._optim_history = utils.load_checkpoint(
+            filename, self.model, self.criterion, self.optimizer, self.lr_scheduler,
+            cuda_device=device_id)
+        return epoch, batch_offset
 
-    def train_step(self, samples, criterion):
+    def train_step(self, samples):
         """Do forward, backward and gradient step in parallel."""
-        assert isinstance(criterion, FairseqCriterion)
-
         # PyTorch initializes gradient buffers lazily, so the first
         # train step needs to send non-empty samples to all replicas
         replace_empty_samples = False
@@ -133,31 +135,36 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
-        criterion.prepare(samples)
+
+        # calculate gradient normalization term
+        grad_denom = self.criterion.grad_denom(samples)
 
         # forward pass, backward pass and gradient step
         losses = [
-            self.call_async(rank, '_async_train_step', criterion=criterion)
+            self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ]
 
         # aggregate losses and gradient norms
         losses, grad_norms = Future.gen_tuple_list(losses)
-        loss = criterion.aggregate(losses)
+        loss = self.criterion.aggregate(losses)
 
         return loss, grad_norms[0]
 
-    def _async_train_step(self, rank, device_id, criterion):
+    def _async_train_step(self, rank, device_id, grad_denom):
         self.model.train()
 
-        # zero grads even if net_input is None, since we will all-reduce them
+        # zero grads even if self._sample is None, since we will all-reduce them
         self.optimizer.zero_grad()
 
         # calculate loss and grads
         loss = 0
        if self._sample is not None:
+            self._sample = self.criterion.prepare(self.model, self._sample)
             net_output = self.model(**self._sample['net_input'])
-            loss_ = criterion(net_output, self._sample)
+            loss_ = self.criterion(net_output, self._sample)
+            if grad_denom is not None:
+                loss_ /= grad_denom
             loss_.backward()
             loss = loss_.data[0]
@@ -196,29 +203,34 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             flat_grads.div_(coef)
 
         return norm
 
-    def valid_step(self, samples, criterion):
+    def valid_step(self, samples):
         """Do forward pass in parallel."""
         # scatter sample across GPUs
         self._scatter_samples(samples, volatile=True)
-        criterion.prepare(samples)
+
+        # calculate gradient normalization term
+        grad_denom = self.criterion.grad_denom(samples)
 
         # forward pass
         losses = [
-            self.call_async(rank, '_async_valid_step', criterion=criterion)
+            self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
         ]
 
         # aggregate losses
-        loss = criterion.aggregate(Future.gen_list(losses))
+        loss = self.criterion.aggregate(Future.gen_list(losses))
 
         return loss
 
-    def _async_valid_step(self, rank, device_id, criterion):
+    def _async_valid_step(self, rank, device_id, grad_denom):
         if self._sample is None:
             return 0
 
         self.model.eval()
+        self._sample = self.criterion.prepare(self.model, self._sample)
         net_output = self.model(**self._sample['net_input'])
-        loss = criterion(net_output, self._sample)
+        loss = self.criterion(net_output, self._sample)
+        if grad_denom is not None:
+            loss /= grad_denom
         return loss.data[0]
 
     def get_lr(self):
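Putting the trainer changes together: `train_step` now computes one `grad_denom` for the whole group of scattered samples and ships it to every replica, and each replica divides its own loss by it before `backward()`. A single-process sketch of that flow with a toy `nn.Linear` model (no NCCL, no multiprocessing; the layout of `samples` is simplified and not fairseq's):

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(8, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

samples = [  # one entry per replica after _scatter_samples
    {'ntokens': 4, 'net_input': {'input': torch.randn(4, 8)}, 'target': torch.randint(1, 5, (4,))},
    {'ntokens': 2, 'net_input': {'input': torch.randn(2, 8)}, 'target': torch.randint(1, 5, (2,))},
]

# computed once in train_step and passed to every replica
grad_denom = sum(s['ntokens'] if s else 0 for s in samples)

optimizer.zero_grad()
total_loss = 0.0
for sample in samples:  # each replica runs this as _async_train_step
    net_output = model(**sample['net_input'])
    loss_ = F.cross_entropy(net_output, sample['target'], reduction='sum')
    if grad_denom is not None:
        loss_ = loss_ / grad_denom  # mirrors `loss_ /= grad_denom` above
    loss_.backward()
    total_loss += loss_.item()
optimizer.step()
print(total_loss)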
fairseq/utils.py

@@ -46,15 +46,23 @@ def torch_persistent_save(*args, **kwargs):
         logging.error(traceback.format_exc())
 
 
-def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, val_loss=None):
+def save_checkpoint(args, epoch, batch_offset, model, criterion, optimizer, lr_scheduler,
+                    val_loss=None, optim_history=None):
+    if optim_history is None:
+        optim_history = []
     state_dict = {
         'args': args,
         'epoch': epoch,
         'batch_offset': batch_offset,
         'model': model.state_dict(),
-        'optimizer': optimizer.state_dict(),
-        'best_loss': lr_scheduler.best,
         'val_loss': val_loss,
+        'optimizer_history': optim_history + [
+            {
+                'criterion_name': criterion.__class__.__name__,
+                'optimizer': optimizer.state_dict(),
+                'best_loss': lr_scheduler.best,
+                'val_loss': val_loss,
+            }
+        ],
     }
     if batch_offset == 0:
@@ -72,9 +80,9 @@ def save_checkpoint(args, epoch, batch_offset, model, optimizer, lr_scheduler, v
     torch_persistent_save(state_dict, last_filename)
 
 
-def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
+def load_checkpoint(filename, model, criterion, optimizer, lr_scheduler, cuda_device=None):
     if not os.path.exists(filename):
-        return 1, 0
+        return 1, 0, []
     if cuda_device is None:
         state = torch.load(filename)
     else:
@@ -82,16 +90,41 @@ def load_checkpoint(filename, model, optimizer, lr_scheduler, cuda_device=None):
             filename,
             map_location=lambda s, l: default_restore_location(s, 'cuda:{}'.format(cuda_device))
         )
+    state = _upgrade_state_dict(state)
 
     model.load_state_dict(state['model'])
-    optimizer.load_state_dict(state['optimizer'])
-    lr_scheduler.best = state['best_loss']
     epoch = state['epoch'] + 1
     batch_offset = state['batch_offset']
 
+    # only load optimizer and lr_scheduler if they match with the checkpoint
+    opt_str = ''
+    optim_history = state['optimizer_history']
+    last_optim = optim_history[-1]
+    if last_optim['criterion_name'] == criterion.__class__.__name__:
+        optimizer.load_state_dict(last_optim['optimizer'])
+        lr_scheduler.best = last_optim['best_loss']
+        opt_str = '; criterion: {}'.format(last_optim['criterion_name'])
+
     gpu_str = ' on GPU #{}'.format(cuda_device) if cuda_device is not None else ''
-    print('| loaded checkpoint {} (epoch {}){}'.format(filename, epoch, gpu_str))
-    return epoch, batch_offset
+    print('| loaded checkpoint {} (epoch {}{}){}'.format(filename, epoch, opt_str, gpu_str))
+    return epoch, batch_offset, optim_history
+
+
+def _upgrade_state_dict(state):
+    """Helper for upgrading old model checkpoints."""
+    # add optimizer_history
+    if 'optimizer_history' not in state:
+        state['optimizer_history'] = [
+            {
+                'criterion_name': criterions.CrossEntropyCriterion.__name__,
+                'optimizer': state['optimizer'],
+                'best_loss': state['best_loss'],
+            },
+        ]
+        del state['optimizer']
+        del state['best_loss']
+    return state
 
 
 def load_ensemble_for_inference(filenames, data_path, split):
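The net result for checkpoints is a new top-level `optimizer_history` list, with `_upgrade_state_dict` folding old flat checkpoints into that layout on load. A small self-contained sketch of the layout and the upgrade step, using dummy dicts in place of real `state_dict()` contents:

def upgrade_state_dict(state):
    # mirrors the _upgrade_state_dict logic above (the diff assumes old
    # checkpoints were trained with the cross-entropy criterion)
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'optimizer': state['optimizer'],
                'best_loss': state['best_loss'],
            },
        ]
        del state['optimizer']
        del state['best_loss']
    return state


old_checkpoint = {
    'epoch': 3,
    'batch_offset': 0,
    'model': {},                # dummy; really model.state_dict()
    'optimizer': {'lr': 0.25},  # dummy; really optimizer.state_dict()
    'best_loss': 4.2,
}
state = upgrade_state_dict(old_checkpoint)

# load_checkpoint now restores optimizer/lr_scheduler state only from the most
# recent history entry, and only if its criterion_name matches the current criterion
last_optim = state['optimizer_history'][-1]
print(last_optim['criterion_name'], last_optim['best_loss'])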
train.py

@@ -68,7 +68,7 @@ def main():
     criterion = utils.build_criterion(args, dataset)
 
     # Start multiprocessing
-    trainer = MultiprocessingTrainer(args, model)
+    trainer = MultiprocessingTrainer(args, model, criterion)
 
     # Load the latest checkpoint if one is available
     epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
@@ -81,11 +81,11 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
 
         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
@@ -102,7 +102,7 @@ def main():
     trainer.stop()
 
 
-def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
@@ -121,7 +121,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
-            loss, grad_norm = trainer.train_step(sample, criterion)
+            loss, grad_norm = trainer.train_step(sample)
 
             ntokens = sum(s['ntokens'] for s in sample)
             src_size = sum(s['src_tokens'].size(0) for s in sample)
@@ -160,7 +160,7 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
         gnorm_meter.avg))
 
 
-def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
+def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.dataloader(subset, batch_size=None,
@@ -173,7 +173,7 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
     with progress_bar(itr, desc, leave=False) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
             ntokens = sum(s['ntokens'] for s in sample)
-            loss = trainer.valid_step(sample, criterion)
+            loss = trainer.valid_step(sample)
             loss_meter.update(loss, ntokens)
             t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)