Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
dd31fa92
Commit
dd31fa92
authored
Jan 16, 2018
by
Sergey Edunov
Committed by
Myle Ott
Jan 22, 2018
Browse files
Report log likelihood for label smoothing
parent
c5378602
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
3 deletions
+21
-3
fairseq/criterions/label_smoothed_cross_entropy.py
fairseq/criterions/label_smoothed_cross_entropy.py
+3
-0
train.py
train.py
+18
-3
No files found.
fairseq/criterions/label_smoothed_cross_entropy.py
View file @
dd31fa92
...
@@ -65,9 +65,11 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
...
@@ -65,9 +65,11 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
)
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
)
target
=
sample
[
'target'
].
view
(
-
1
)
target
=
sample
[
'target'
].
view
(
-
1
)
loss
=
LabelSmoothedNLLLoss
.
apply
(
lprobs
,
target
,
self
.
eps
,
self
.
padding_idx
,
self
.
weights
,
reduce
)
loss
=
LabelSmoothedNLLLoss
.
apply
(
lprobs
,
target
,
self
.
eps
,
self
.
padding_idx
,
self
.
weights
,
reduce
)
nll_loss
=
F
.
nll_loss
(
lprobs
,
target
,
size_average
=
False
,
ignore_index
=
self
.
padding_idx
,
reduce
=
reduce
)
sample_size
=
sample
[
'target'
].
size
(
0
)
if
self
.
args
.
sentence_avg
else
sample
[
'ntokens'
]
sample_size
=
sample
[
'target'
].
size
(
0
)
if
self
.
args
.
sentence_avg
else
sample
[
'ntokens'
]
logging_output
=
{
logging_output
=
{
'loss'
:
loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'loss'
:
loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'nll_loss'
:
nll_loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'sample_size'
:
sample_size
,
'sample_size'
:
sample_size
,
}
}
return
loss
,
sample_size
,
logging_output
return
loss
,
sample_size
,
logging_output
...
@@ -78,4 +80,5 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
...
@@ -78,4 +80,5 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
sample_size
=
sum
(
log
.
get
(
'sample_size'
,
0
)
for
log
in
logging_outputs
)
sample_size
=
sum
(
log
.
get
(
'sample_size'
,
0
)
for
log
in
logging_outputs
)
return
{
return
{
'loss'
:
sum
(
log
.
get
(
'loss'
,
0
)
for
log
in
logging_outputs
)
/
sample_size
/
math
.
log
(
2
),
'loss'
:
sum
(
log
.
get
(
'loss'
,
0
)
for
log
in
logging_outputs
)
/
sample_size
/
math
.
log
(
2
),
'nll_loss'
:
sum
(
log
.
get
(
'nll_loss'
,
0
)
for
log
in
logging_outputs
)
/
sample_size
/
math
.
log
(
2
),
}
}
train.py
View file @
dd31fa92
...
@@ -150,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
...
@@ -150,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
sample_without_replacement
=
args
.
sample_without_replacement
,
sample_without_replacement
=
args
.
sample_without_replacement
,
sort_by_source_size
=
(
epoch
<=
args
.
curriculum
))
sort_by_source_size
=
(
epoch
<=
args
.
curriculum
))
loss_meter
=
AverageMeter
()
loss_meter
=
AverageMeter
()
nll_loss_meter
=
AverageMeter
()
bsz_meter
=
AverageMeter
()
# sentences per batch
bsz_meter
=
AverageMeter
()
# sentences per batch
wpb_meter
=
AverageMeter
()
# words per batch
wpb_meter
=
AverageMeter
()
# words per batch
wps_meter
=
TimeMeter
()
# words per second
wps_meter
=
TimeMeter
()
# words per second
...
@@ -164,6 +165,11 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
...
@@ -164,6 +165,11 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
if
'nll_loss'
in
loss_dict
:
nll_loss
=
loss_dict
[
'nll_loss'
]
nll_loss_meter
.
update
(
nll_loss
,
ntokens
)
nsentences
=
sum
(
s
[
'net_input'
][
'src_tokens'
].
size
(
0
)
for
s
in
sample
)
nsentences
=
sum
(
s
[
'net_input'
][
'src_tokens'
].
size
(
0
)
for
s
in
sample
)
loss_meter
.
update
(
loss
,
nsentences
if
args
.
sentence_avg
else
ntokens
)
loss_meter
.
update
(
loss
,
nsentences
if
args
.
sentence_avg
else
ntokens
)
bsz_meter
.
update
(
nsentences
)
bsz_meter
.
update
(
nsentences
)
...
@@ -193,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
...
@@ -193,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
t
.
print
(
collections
.
OrderedDict
([
t
.
print
(
collections
.
OrderedDict
([
(
'train loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'train loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'train ppl'
,
get_perplexity
(
loss_meter
.
avg
)),
(
'train ppl'
,
get_perplexity
(
nll_loss_meter
.
avg
if
nll_loss_meter
.
count
>
0
else
loss_meter
.
avg
)),
(
's/checkpoint'
,
round
(
wps_meter
.
elapsed_time
)),
(
's/checkpoint'
,
round
(
wps_meter
.
elapsed_time
)),
(
'words/s'
,
round
(
wps_meter
.
avg
)),
(
'words/s'
,
round
(
wps_meter
.
avg
)),
(
'words/batch'
,
round
(
wpb_meter
.
avg
)),
(
'words/batch'
,
round
(
wpb_meter
.
avg
)),
...
@@ -242,16 +250,21 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
...
@@ -242,16 +250,21 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
descending
=
True
,
# largest batch first to warm the caching allocator
descending
=
True
,
# largest batch first to warm the caching allocator
)
)
loss_meter
=
AverageMeter
()
loss_meter
=
AverageMeter
()
nll_loss_meter
=
AverageMeter
()
extra_meters
=
collections
.
defaultdict
(
lambda
:
AverageMeter
())
extra_meters
=
collections
.
defaultdict
(
lambda
:
AverageMeter
())
prefix
=
'valid on
\'
{}
\'
subset'
.
format
(
subset
)
prefix
=
'valid on
\'
{}
\'
subset'
.
format
(
subset
)
with
utils
.
build_progress_bar
(
args
,
itr
,
epoch
,
prefix
)
as
t
:
with
utils
.
build_progress_bar
(
args
,
itr
,
epoch
,
prefix
)
as
t
:
for
_
,
sample
in
data
.
skip_group_enumerator
(
t
,
args
.
num_gpus
):
for
_
,
sample
in
data
.
skip_group_enumerator
(
t
,
args
.
num_gpus
):
loss_dict
=
trainer
.
valid_step
(
sample
)
loss_dict
=
trainer
.
valid_step
(
sample
)
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
loss
=
loss_dict
[
'loss'
]
loss
=
loss_dict
[
'loss'
]
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
if
'nll_loss'
in
loss_dict
:
nll_loss
=
loss_dict
[
'nll_loss'
]
nll_loss_meter
.
update
(
nll_loss
,
ntokens
)
loss_meter
.
update
(
loss
,
ntokens
)
loss_meter
.
update
(
loss
,
ntokens
)
extra_postfix
=
[]
extra_postfix
=
[]
...
@@ -265,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
...
@@ -265,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
t
.
print
(
collections
.
OrderedDict
([
t
.
print
(
collections
.
OrderedDict
([
(
'valid loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'valid loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'valid ppl'
,
get_perplexity
(
loss_meter
.
avg
)),
(
'valid ppl'
,
get_perplexity
(
nll_loss_meter
.
avg
if
nll_loss_meter
.
count
>
0
else
loss_meter
.
avg
)),
]
+
[
]
+
[
(
k
,
meter
.
avg
)
(
k
,
meter
.
avg
)
for
k
,
meter
in
extra_meters
.
items
()
for
k
,
meter
in
extra_meters
.
items
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment