OpenDAS / Fairseq · Commits

Commit 8bafae2e, authored Sep 28, 2017 by Myle Ott (parent e432459b)

Better logging from criterions
Showing 5 changed files with 80 additions and 62 deletions (+80 -62):

    fairseq/criterions/cross_entropy.py                  +10  -5
    fairseq/criterions/fairseq_criterion.py               +3  -11
    fairseq/criterions/label_smoothed_cross_entropy.py   +10  -5
    fairseq/multiprocessing_trainer.py                   +14  -21
    train.py                                             +43  -20
fairseq/criterions/cross_entropy.py

...
@@ -21,11 +21,16 @@ class CrossEntropyCriterion(FairseqCriterion):
     def grad_denom(self, samples):
         return sum(s['ntokens'] if s else 0 for s in samples)
 
-    def forward(self, net_output, sample):
+    def forward(self, model, sample, grad_denom):
+        net_output = model(**sample['net_input'])
         input = net_output.view(-1, net_output.size(-1))
         target = sample['target'].view(-1)
         loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
-        return loss
+        return {
+            'loss': loss / grad_denom,
+        }
 
-    def aggregate(self, losses):
-        return sum(losses) / math.log(2)
+    def aggregate(self, loss_dicts):
+        return {
+            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+        }
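
The change above moves normalization into the criterion: each worker divides its summed cross-entropy by grad_denom (the token count over the whole mini-batch) and returns a dict, and aggregate() sums those entries and converts nats to bits. The following framework-free sketch (not fairseq code; the shard losses are made-up numbers) illustrates why summing per-shard losses that were each divided by the global token count yields the token-averaged loss:

# Minimal sketch of the normalization/aggregation scheme used above.
import math

def grad_denom(samples):
    # total number of target tokens across all shards (None-safe, as above)
    return sum(s['ntokens'] if s else 0 for s in samples)

def shard_forward(summed_loss_nats, denom):
    # mirrors returning {'loss': loss / grad_denom}
    return {'loss': summed_loss_nats / denom}

def aggregate(loss_dicts):
    # mirrors summing shard losses and dividing by log(2) to report bits per token
    return {'loss': sum(d['loss'] for d in loss_dicts if 'loss' in d) / math.log(2)}

samples = [{'ntokens': 300}, {'ntokens': 200}]
denom = grad_denom(samples)                 # 500 tokens in the mini-batch
shards = [shard_forward(1200.0, denom),     # hypothetical summed losses in nats
          shard_forward(800.0, denom)]
print(aggregate(shards))                    # {'loss': 4.0 / log(2) ~ 5.77 bits/token}
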
fairseq/criterions/fairseq_criterion.py

...
@@ -18,18 +18,10 @@ class FairseqCriterion(_Loss):
         """Gradient normalization term for DataParallel training."""
         raise NotImplementedError
 
-    def prepare(self, model, sample):
-        """Apply criterion-specific modifications to the sample."""
-        return sample
-
-    def forward(self, net_output, sample):
+    def forward(self, model, sample, grad_denom):
         """Compute the loss for the given sample and network output."""
         raise NotImplementedError
 
-    def aggregate(self, losses):
-        """Aggregate losses from DataParallel training.
-
-        Takes a list of losses as input (as returned by forward) and
-        aggregates them into the total loss for the mini-batch.
-        """
+    def aggregate(self, losses, log_infos):
+        """Aggregate losses from DataParallel training."""
        raise NotImplementedError
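
Under the revised interface the criterion receives the model itself and runs the forward pass, so prepare() is gone and the return value is a dict rather than a bare loss. The toy class below is an illustrative sketch only: ToyCriterion and its numbers are invented, the real base class derives from torch.nn.modules.loss._Loss, and the concrete criterions in this commit implement aggregate(loss_dicts) rather than the two-argument form shown in the abstract signature.

# Illustrative-only sketch of a subclass shaped like the new interface.
import math

class ToyCriterion(object):

    def grad_denom(self, samples):
        """Gradient normalization term for DataParallel training."""
        return sum(s['ntokens'] if s else 0 for s in samples)

    def forward(self, model, sample, grad_denom):
        """Compute the loss for the given sample and network output."""
        net_output = model(**sample['net_input'])   # the criterion now calls the model
        loss = net_output                            # stand-in for a real loss computation
        return {'loss': loss / grad_denom}           # dict, not a bare scalar

    def aggregate(self, loss_dicts):
        """Aggregate per-worker loss dicts into logging output."""
        return {'loss': sum(d['loss'] for d in loss_dicts) / math.log(2)}

crit = ToyCriterion()
samples = [{'ntokens': 4, 'net_input': {'x': 2.0}}]
denom = crit.grad_denom(samples)
out = crit.forward(lambda x: x * 10.0, samples[0], denom)   # fake "model"
print(crit.aggregate([out]))
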
fairseq/criterions/label_smoothed_cross_entropy.py

...
@@ -52,11 +52,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
     def grad_denom(self, samples):
         return sum(s['ntokens'] if s else 0 for s in samples)
 
-    def forward(self, net_output, sample):
+    def forward(self, model, sample, grad_denom):
+        net_output = model(**sample['net_input'])
         input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
         target = sample['target'].view(-1)
         loss = LabelSmoothedCrossEntropy.apply(input, target, self.eps, self.padding_idx, self.weights)
-        return loss
+        return {
+            'loss': loss / grad_denom,
+        }
 
-    def aggregate(self, losses):
-        return sum(losses) / math.log(2)
+    def aggregate(self, loss_dicts):
+        return {
+            'loss': sum(l['loss'].data[0] for l in loss_dicts if 'loss' in l) / math.log(2),
+        }
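
This hunk is structurally identical to the plain cross-entropy change; the loss itself still comes from the custom LabelSmoothedCrossEntropy autograd Function, which is not shown in this diff. As a hedged reference only, a standard label-smoothed loss mixes the NLL of the true class with the mean NLL over the vocabulary. The sketch below computes that textbook quantity with plain tensor ops; it is not the fairseq Function (which also handles padding and per-class weights) and the inputs are fabricated.

# Reference-only sketch of a standard eps-smoothed loss on log-probabilities.
import torch
import torch.nn.functional as F

def label_smoothed_nll(logits, target, eps):
    lprobs = F.log_softmax(logits, dim=-1)                    # (batch, vocab)
    nll = -lprobs.gather(1, target.unsqueeze(1)).squeeze(1)   # NLL of the true class
    smooth = -lprobs.mean(dim=-1)                             # mean NLL over the vocabulary
    return ((1.0 - eps) * nll + eps * smooth).sum()           # summed, like size_average=False

logits = torch.randn(8, 20)          # hypothetical batch of 8, vocabulary of 20
target = torch.randint(0, 20, (8,))
print(label_smoothed_nll(logits, target, eps=0.1))
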
fairseq/multiprocessing_trainer.py

...
@@ -146,10 +146,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         ]
 
         # aggregate losses and gradient norms
-        losses, grad_norms = Future.gen_tuple_list(losses)
-        loss = self.criterion.aggregate(losses)
+        loss_dicts = Future.gen_list(losses)
+        loss_dict = self.criterion.aggregate(loss_dicts)
+        loss_dict['gnorm'] = loss_dicts[0]['gnorm']
 
-        return loss, grad_norms[0]
+        return loss_dict
 
     def _async_train_step(self, rank, device_id, grad_denom):
         self.model.train()
...
@@ -159,14 +160,11 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         # calculate loss and grads
         loss = 0
+        loss_dict = {}
         if self._sample is not None:
-            self._sample = self.criterion.prepare(self.model, self._sample)
-            net_output = self.model(**self._sample['net_input'])
-            loss_ = self.criterion(net_output, self._sample)
-            if grad_denom is not None:
-                loss_ /= grad_denom
-            loss_.backward()
-            loss = loss_.data[0]
+            loss_dict = self.criterion(self.model, self._sample, grad_denom)
+            loss_dict['loss'].backward()
+            loss = loss_dict['loss'].data[0]
 
         # flatten grads into a contiguous block of memory
         if self.flat_grads is None:
...
@@ -176,12 +174,12 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         nccl.all_reduce(self.flat_grads)
 
         # clip grads
-        grad_norm = self._clip_grads_(self.flat_grads, self.args.clip_norm)
+        loss_dict['gnorm'] = self._clip_grads_(self.flat_grads, self.args.clip_norm)
 
         # take an optimization step
         self.optimizer.step()
 
-        return loss, grad_norm
+        return loss_dict
 
     def _flatten_grads_(self, model):
         num_params = sum(p.data.numel() for p in model.parameters())
...
@@ -218,20 +216,15 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         ]
 
         # aggregate losses
-        loss = self.criterion.aggregate(Future.gen_list(losses))
+        loss_dict = self.criterion.aggregate(Future.gen_list(losses))
 
-        return loss
+        return loss_dict
 
     def _async_valid_step(self, rank, device_id, grad_denom):
         if self._sample is None:
-            return 0
+            return {}
 
         self.model.eval()
-        self._sample = self.criterion.prepare(self.model, self._sample)
-        net_output = self.model(**self._sample['net_input'])
-        loss = self.criterion(net_output, self._sample)
-        if grad_denom is not None:
-            loss /= grad_denom
-        return loss.data[0]
+        return self.criterion(self.model, self._sample, grad_denom)
 
     def get_lr(self):
         """Get the current learning rate."""
...
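
After this change each worker's _async_train_step hands back the whole loss dict instead of a (loss, grad_norm) pair, and train_step aggregates the per-GPU dicts and copies 'gnorm' from worker 0 (the gradients were all-reduced, so every worker clips the same norm). The mock below is independent of the Future/NCCL machinery and uses fabricated numbers; it only traces that control flow.

# Framework-free mock of the new train_step() flow shown above.
import math

def fake_async_train_step(worker_loss_nats, denom, gnorm):
    # what each GPU hands back after this commit
    return {'loss': worker_loss_nats / denom, 'gnorm': gnorm}

def aggregate(loss_dicts):
    return {'loss': sum(d['loss'] for d in loss_dicts if 'loss' in d) / math.log(2)}

def fake_train_step(worker_outputs):
    loss_dict = aggregate(worker_outputs)
    # gradients were all-reduced, so every worker computed the same norm;
    # taking worker 0's value mirrors loss_dict['gnorm'] = loss_dicts[0]['gnorm']
    loss_dict['gnorm'] = worker_outputs[0]['gnorm']
    return loss_dict

workers = [fake_async_train_step(1200.0, 500, gnorm=3.2),
           fake_async_train_step(800.0, 500, gnorm=3.2)]
print(fake_train_step(workers))   # {'loss': ~5.77, 'gnorm': 3.2}
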
train.py

...
@@ -115,13 +115,15 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     wpb_meter = AverageMeter()   # words per batch
     wps_meter = TimeMeter()      # words per second
     clip_meter = AverageMeter()  # % of updates clipped
-    gnorm_meter = AverageMeter() # gradient norm
+    extra_meters = collections.defaultdict(lambda: AverageMeter())
 
     desc = '| epoch {:03d}'.format(epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
-            loss, grad_norm = trainer.train_step(sample)
+            loss_dict = trainer.train_step(sample)
+            loss = loss_dict['loss']
+            del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
             ntokens = sum(s['ntokens'] for s in sample)
             src_size = sum(s['src_tokens'].size(0) for s in sample)
...
@@ -129,8 +131,12 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             bsz_meter.update(src_size)
             wpb_meter.update(ntokens)
             wps_meter.update(ntokens)
-            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
-            gnorm_meter.update(grad_norm)
+            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
+
+            extra_postfix = []
+            for k, v in loss_dict.items():
+                extra_meters[k].update(v)
+                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
 
             t.set_postfix(collections.OrderedDict([
                 ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
...
@@ -139,8 +145,7 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
                 ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                 ('lr', lr),
                 ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
-            ]), refresh=False)
+            ] + extra_postfix), refresh=False)
 
             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation
...
@@ -148,16 +153,17 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                 trainer.save_checkpoint(args, epoch, i + 1)
 
-        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
-        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
-        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
-        t.write(fmt.format(
-            loss_meter.avg, math.pow(2, loss_meter.avg),
-            round(wps_meter.elapsed_time),
-            round(wps_meter.avg),
-            round(wpb_meter.avg),
-            round(bsz_meter.avg),
-            lr, clip_meter.avg * 100,
-            gnorm_meter.avg))
+        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
+            loss_meter.avg, math.pow(2, loss_meter.avg))
+        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
+            round(wps_meter.elapsed_time),
+            round(wps_meter.avg),
+            round(wpb_meter.avg))
+        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
+            round(bsz_meter.avg),
+            lr, clip_meter.avg * 100)
+        fmt += ''.join(
+            ' | {} {:.4f}'.format(k, meter.avg)
+            for k, meter in extra_meters.items()
+        )
+        t.write(fmt)
 
 def validate(args, epoch, trainer, dataset, subset, ngpus):
...
@@ -168,18 +174,35 @@ def validate(args, epoch, trainer, dataset, subset, ngpus):
         max_positions=args.max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
+    extra_meters = collections.defaultdict(lambda: AverageMeter())
     desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
     with progress_bar(itr, desc, leave=False) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
+            loss_dict = trainer.valid_step(sample)
+            loss = loss_dict['loss']
+            del loss_dict['loss']  # don't include in extra_meters or extra_postfix
+
             ntokens = sum(s['ntokens'] for s in sample)
-            loss = trainer.valid_step(sample)
             loss_meter.update(loss, ntokens)
-            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
+
+            extra_postfix = []
+            for k, v in loss_dict.items():
+                extra_meters[k].update(v)
+                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+
+            t.set_postfix(collections.OrderedDict([
+                ('loss', '{:.2f}'.format(loss_meter.avg)),
+            ] + extra_postfix), refresh=False)
 
         val_loss = loss_meter.avg
-        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
-            val_loss, math.pow(2, val_loss)))
+        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
+            val_loss, math.pow(2, val_loss))
+        fmt += ''.join(
+            ' | {} {:.4f}'.format(k, meter.avg)
+            for k, meter in extra_meters.items()
+        )
+        t.write(fmt)
 
     # update and return the learning rate
     return val_loss
...
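
With this commit train.py stops hard-coding a gnorm meter: every extra key left in the loss dict (gnorm today, anything a criterion adds tomorrow) automatically gets its own AverageMeter and appears in the progress-bar postfix and the end-of-epoch line. The sketch below shows that defaultdict-of-meters pattern in isolation; SimpleAverageMeter is a stand-in for fairseq's AverageMeter and the loss dicts are fabricated examples of what train_step() might return.

# Self-contained sketch of the "extra meters" logging pattern introduced above.
import collections

class SimpleAverageMeter(object):
    def __init__(self):
        self.sum, self.count = 0.0, 0
    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
    @property
    def avg(self):
        return self.sum / max(self.count, 1)

extra_meters = collections.defaultdict(lambda: SimpleAverageMeter())

for loss_dict in [{'loss': 5.8, 'gnorm': 3.2}, {'loss': 5.5, 'gnorm': 2.9}]:
    loss = loss_dict['loss']
    del loss_dict['loss']           # 'loss' keeps its dedicated meter elsewhere
    extra_postfix = []
    for k, v in loss_dict.items():  # every remaining key gets a meter automatically
        extra_meters[k].update(v)
        extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
    print(extra_postfix)

# end-of-epoch summary line, mirroring fmt += ''.join(...)
print(''.join(' | {} {:.4f}'.format(k, m.avg) for k, m in extra_meters.items()))
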