OpenDAS / Fairseq · Commit 3f970086
Authored Oct 06, 2017 by Myle Ott

More flexible gradient normalization

Parent: 88a8bd42
Showing 4 changed files with 98 additions and 63 deletions:

- fairseq/criterions/cross_entropy.py (+18 / -8)
- fairseq/criterions/fairseq_criterion.py (+15 / -7)
- fairseq/criterions/label_smoothed_cross_entropy.py (+18 / -8)
- fairseq/multiprocessing_trainer.py (+47 / -40)
fairseq/criterions/cross_entropy.py

@@ -18,19 +18,29 @@ class CrossEntropyCriterion(FairseqCriterion):
         super().__init__()
         self.padding_idx = padding_idx
 
-    def grad_denom(self, samples):
-        return sum(s['ntokens'] if s else 0 for s in samples)
-
-    def forward(self, model, sample, grad_denom):
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return {
-            'loss': loss / grad_denom,
-        }
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
 
-    def aggregate(self, loss_dicts):
-        return {
-            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
-        }
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+        return {
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+        }
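
To make the new contract concrete: forward now returns a (loss, sample_size, logging_output) tuple, and aggregation works on plain logging dicts rather than on loss Variables. The standalone sketch below mirrors that shape with a toy criterion (hypothetical ToyCriterion, no torch dependency) just to show how the three return values and aggregate_logging_outputs fit together; it is not fairseq code.

import math

class ToyCriterion:
    """Toy stand-in that mirrors the new criterion contract (hypothetical)."""

    def forward(self, model, sample):
        # pretend `model` maps the sample's input to a token-summed loss in nats
        loss = model(sample['net_input'])
        sample_size = sample['ntokens']
        logging_output = {'loss': loss, 'sample_size': sample_size}
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        # same aggregation as the hunk above: per-token loss, converted from nats to bits
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
        }

# one logging dict per replica, as the trainer collects them
logs = [{'loss': 6.93, 'sample_size': 10}, {'loss': 13.86, 'sample_size': 20}]
print(ToyCriterion.aggregate_logging_outputs(logs))  # {'loss': ~1.0}, i.e. about 1 bit per token
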
fairseq/criterions/fairseq_criterion.py

@@ -14,14 +14,22 @@ class FairseqCriterion(_Loss):
     def __init__(self):
         super().__init__()
 
-    def grad_denom(self, samples):
-        """Gradient normalization term for DataParallel training."""
-        raise NotImplementedError
-
-    def forward(self, model, sample, grad_denom):
-        """Compute the loss for the given sample and network output."""
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         raise NotImplementedError
 
-    def aggregate(self, losses, log_infos):
-        """Aggregate losses from DataParallel training."""
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
         raise NotImplementedError
+
+    @staticmethod
+    def grad_denom(sample_sizes):
+        """Compute the gradient denominator for a set of sample sizes."""
+        return sum(sample_sizes)
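
With this base class, a concrete criterion only has to implement forward; grad_denom defaults to summing the per-replica sample sizes, and replicas that received no sample contribute 0 (the trainer change below returns 0, {} in that case). A criterion is also free to override the denominator, which is presumably the flexibility the commit title refers to. A minimal sketch with hypothetical names and numbers, not fairseq code:

class CustomDenomCriterion:
    """Hypothetical criterion that overrides the default token-count denominator."""

    @staticmethod
    def grad_denom(sample_sizes):
        # normalize by the number of replicas that actually had data,
        # instead of by the total token count
        return sum(1 for s in sample_sizes if s > 0)

sample_sizes = [512, 480, 0, 505]                     # third replica had no sample
print(sum(sample_sizes))                              # 1497: the base-class default
print(CustomDenomCriterion.grad_denom(sample_sizes))  # 3: a custom choice
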
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -49,19 +49,29 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.padding_idx = padding_idx
         self.weights = weights
 
-    def grad_denom(self, samples):
-        return sum(s['ntokens'] if s else 0 for s in samples)
-
-    def forward(self, model, sample, grad_denom):
+    def forward(self, model, sample):
+        """Compute the loss for the given sample.
+
+        Returns a tuple with three elements:
+        1) the loss, as a Variable
+        2) the sample size, which is used as the denominator for the gradient
+        3) logging outputs to display while training
+        """
         net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return {
-            'loss': loss / grad_denom,
-        }
+        sample_size = sample['ntokens']
+        logging_output = {
+            'loss': loss.data[0],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
 
-    def aggregate(self, loss_dicts):
-        return {
-            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
-        }
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
+        return {
+            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+        }
fairseq/multiprocessing_trainer.py

@@ -15,7 +15,6 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
 
 from fairseq import nccl, utils
-from fairseq.criterions import FairseqCriterion
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG
 
@@ -74,6 +73,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
             momentum=self.args.momentum,
             weight_decay=self.args.weight_decay)
         self.flat_grads = None
+        self.loss = None
 
         # initialize LR scheduler
         self.lr_scheduler = self._build_lr_scheduler()
@@ -136,35 +136,44 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, replace_empty_samples=replace_empty_samples)
 
-        # calculate gradient normalization term
-        grad_denom = self.criterion.grad_denom(samples)
+        # forward pass
+        sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward')
+            for rank in range(self.num_replicas)
+        ])
 
-        # forward pass, backward pass and gradient step
-        losses = [
-            self.call_async(rank, '_async_train_step', grad_denom=grad_denom)
+        # backward pass, all-reduce gradients and take an optimization step
+        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
+        grad_norms = Future.gen_list([
+            self.call_async(rank, '_async_backward_and_opt', grad_denom=grad_denom)
             for rank in range(self.num_replicas)
-        ]
+        ])
 
-        # aggregate losses and gradient norms
-        loss_dicts = Future.gen_list(losses)
-        loss_dict = self.criterion.aggregate(loss_dicts)
-        loss_dict['gnorm'] = loss_dicts[0]['gnorm']
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        logging_output['gnorm'] = grad_norms[0]  # log the gradient norm
 
-        return loss_dict
+        return logging_output
 
-    def _async_train_step(self, rank, device_id, grad_denom):
-        self.model.train()
+    def _async_forward(self, rank, device_id, eval=False):
+        if eval:
+            self.model.eval()
+        else:
+            self.model.train()
+            self.optimizer.zero_grad()
 
-        # zero grads even if self._sample is None, since we will all-reduce them
-        self.optimizer.zero_grad()
+        if self._sample is None:
+            return 0, {}
 
-        # calculate loss and grads
-        loss = 0
-        loss_dict = {}
-        if self._sample is not None:
-            loss_dict = self.criterion(self.model, self._sample, grad_denom)
-            loss_dict['loss'].backward()
-            loss = loss_dict['loss'].data[0]
+        # calculate loss and sample size
+        self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+
+        return sample_size, logging_output
+
+    def _async_backward_and_opt(self, rank, device_id, grad_denom):
+        if self.loss is not None:
+            # backward pass
+            self.loss.backward()
 
         # flatten grads into a contiguous block of memory
         if self.flat_grads is None:
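
The train step is now split into two asynchronous phases per replica: a forward pass that returns (sample_size, logging_output) and caches self.loss, then a backward/optimize pass that receives a grad_denom computed from all replicas' sample sizes. A simplified single-process sketch of that control flow, to show the ordering only; `replicas` is a hypothetical list of objects exposing forward() and backward_and_opt(grad_denom), not the fairseq event-loop API.

def train_step(criterion_cls, replicas):
    # phase 1: forward on every replica, collecting sizes and logging dicts
    results = [r.forward() for r in replicas]          # each -> (sample_size, logging_output)
    sample_sizes = [size for size, _ in results]
    logging_outputs = [log for _, log in results]

    # the denominator is computed from *all* replicas' reported sample sizes
    grad_denom = criterion_cls.grad_denom(sample_sizes)

    # phase 2: backward, all-reduce and optimizer step on every replica
    grad_norms = [r.backward_and_opt(grad_denom) for r in replicas]

    logging_output = criterion_cls.aggregate_logging_outputs(logging_outputs)
    logging_output['gnorm'] = grad_norms[0]            # log the gradient norm
    return logging_output

In the real trainer the two phases are dispatched with call_async and collected through Future.gen_tuple_list and Future.gen_list, but the ordering is the same.
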
@@ -173,13 +182,20 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # all-reduce grads
         nccl.all_reduce(self.flat_grads)
 
+        # normalize grads
+        if grad_denom != 0:
+            self.flat_grads.div_(grad_denom)
+
         # clip grads
-        loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
 
         # take an optimization step
         self.optimizer.step()
 
-        return loss_dict
+        # reset loss
+        self.loss = None
+
+        return grad_norm
 
     def _flatten_grads_(self, model):
         num_params = sum(p.data.numel() for p in model.parameters())
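
Dividing the all-reduced gradients by grad_denom here is mathematically equivalent to the old scheme of backpropagating loss / grad_denom on each replica, because gradients are linear in the loss; the difference is that the denominator is now derived from whatever sample_size each criterion reports, rather than hard-coded per criterion. A tiny numeric check of that equivalence, with scalar stand-ins instead of tensors:

# Gradient linearity: normalizing the summed gradients after the all-reduce
# (new code) gives the same update as normalizing each loss before backward
# (old code). Scalar stand-ins for the per-replica gradients dL_r/dw.
per_replica_grads = [3.0, 5.0]   # hypothetical
grad_denom = 32                  # e.g. total tokens across replicas

old_style = sum(g / grad_denom for g in per_replica_grads)  # backward on L_r / D, then all-reduce
new_style = sum(per_replica_grads) / grad_denom             # all-reduce dL_r/dw, then divide
assert abs(old_style - new_style) < 1e-12
print(old_style, new_style)      # 0.25 0.25
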
@@ -206,25 +222,16 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # scatter sample across GPUs
         self._scatter_samples(samples, volatile=True)
 
-        # calculate gradient normalization term
-        grad_denom = self.criterion.grad_denom(samples)
-
         # forward pass
-        losses = [
-            self.call_async(rank, '_async_valid_step', grad_denom=grad_denom)
+        _sample_sizes, logging_outputs = Future.gen_tuple_list([
+            self.call_async(rank, '_async_forward', eval=True)
             for rank in range(self.num_replicas)
-        ]
+        ])
 
-        # aggregate losses
-        loss_dict = self.criterion.aggregate(Future.gen_list(losses))
-
-        return loss_dict
-
-    def _async_valid_step(self, rank, device_id, grad_denom):
-        if self._sample is None:
-            return {}
-        self.model.eval()
-        return self.criterion(self.model, self._sample, grad_denom)
+        # aggregate logging output
+        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
+
+        return logging_output
 
     def get_lr(self):
         """Get the current learning rate."""