Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
dd31fa92
"vscode:/vscode.git/clone" did not exist on "c3c94fe71b2fff2ad54698b4834dd89ad9e0e5d7"
Commit
dd31fa92
authored
Jan 16, 2018
by
Sergey Edunov
Committed by
Myle Ott
Jan 22, 2018
Browse files
Report log likelihood for label smoothing
parent
c5378602
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
3 deletions
+21
-3
fairseq/criterions/label_smoothed_cross_entropy.py
fairseq/criterions/label_smoothed_cross_entropy.py
+3
-0
train.py
train.py
+18
-3
No files found.
fairseq/criterions/label_smoothed_cross_entropy.py
View file @
dd31fa92
...
...
@@ -65,9 +65,11 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
lprobs
=
model
.
get_normalized_probs
(
net_output
,
log_probs
=
True
)
target
=
sample
[
'target'
].
view
(
-
1
)
loss
=
LabelSmoothedNLLLoss
.
apply
(
lprobs
,
target
,
self
.
eps
,
self
.
padding_idx
,
self
.
weights
,
reduce
)
nll_loss
=
F
.
nll_loss
(
lprobs
,
target
,
size_average
=
False
,
ignore_index
=
self
.
padding_idx
,
reduce
=
reduce
)
sample_size
=
sample
[
'target'
].
size
(
0
)
if
self
.
args
.
sentence_avg
else
sample
[
'ntokens'
]
logging_output
=
{
'loss'
:
loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'nll_loss'
:
nll_loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'sample_size'
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
...
...
@@ -78,4 +80,5 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
sample_size
=
sum
(
log
.
get
(
'sample_size'
,
0
)
for
log
in
logging_outputs
)
return
{
'loss'
:
sum
(
log
.
get
(
'loss'
,
0
)
for
log
in
logging_outputs
)
/
sample_size
/
math
.
log
(
2
),
'nll_loss'
:
sum
(
log
.
get
(
'nll_loss'
,
0
)
for
log
in
logging_outputs
)
/
sample_size
/
math
.
log
(
2
),
}
train.py
View file @
dd31fa92
...
...
@@ -150,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
sample_without_replacement
=
args
.
sample_without_replacement
,
sort_by_source_size
=
(
epoch
<=
args
.
curriculum
))
loss_meter
=
AverageMeter
()
nll_loss_meter
=
AverageMeter
()
bsz_meter
=
AverageMeter
()
# sentences per batch
wpb_meter
=
AverageMeter
()
# words per batch
wps_meter
=
TimeMeter
()
# words per second
...
...
@@ -164,6 +165,11 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
if
'nll_loss'
in
loss_dict
:
nll_loss
=
loss_dict
[
'nll_loss'
]
nll_loss_meter
.
update
(
nll_loss
,
ntokens
)
nsentences
=
sum
(
s
[
'net_input'
][
'src_tokens'
].
size
(
0
)
for
s
in
sample
)
loss_meter
.
update
(
loss
,
nsentences
if
args
.
sentence_avg
else
ntokens
)
bsz_meter
.
update
(
nsentences
)
...
...
@@ -193,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
t
.
print
(
collections
.
OrderedDict
([
(
'train loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'train ppl'
,
get_perplexity
(
loss_meter
.
avg
)),
(
'train ppl'
,
get_perplexity
(
nll_loss_meter
.
avg
if
nll_loss_meter
.
count
>
0
else
loss_meter
.
avg
)),
(
's/checkpoint'
,
round
(
wps_meter
.
elapsed_time
)),
(
'words/s'
,
round
(
wps_meter
.
avg
)),
(
'words/batch'
,
round
(
wpb_meter
.
avg
)),
...
...
@@ -242,16 +250,21 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
descending
=
True
,
# largest batch first to warm the caching allocator
)
loss_meter
=
AverageMeter
()
nll_loss_meter
=
AverageMeter
()
extra_meters
=
collections
.
defaultdict
(
lambda
:
AverageMeter
())
prefix
=
'valid on
\'
{}
\'
subset'
.
format
(
subset
)
with
utils
.
build_progress_bar
(
args
,
itr
,
epoch
,
prefix
)
as
t
:
for
_
,
sample
in
data
.
skip_group_enumerator
(
t
,
args
.
num_gpus
):
loss_dict
=
trainer
.
valid_step
(
sample
)
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
loss
=
loss_dict
[
'loss'
]
del
loss_dict
[
'loss'
]
# don't include in extra_meters or extra_postfix
ntokens
=
sum
(
s
[
'ntokens'
]
for
s
in
sample
)
if
'nll_loss'
in
loss_dict
:
nll_loss
=
loss_dict
[
'nll_loss'
]
nll_loss_meter
.
update
(
nll_loss
,
ntokens
)
loss_meter
.
update
(
loss
,
ntokens
)
extra_postfix
=
[]
...
...
@@ -265,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
t
.
print
(
collections
.
OrderedDict
([
(
'valid loss'
,
round
(
loss_meter
.
avg
,
2
)),
(
'valid ppl'
,
get_perplexity
(
loss_meter
.
avg
)),
(
'valid ppl'
,
get_perplexity
(
nll_loss_meter
.
avg
if
nll_loss_meter
.
count
>
0
else
loss_meter
.
avg
)),
]
+
[
(
k
,
meter
.
avg
)
for
k
,
meter
in
extra_meters
.
items
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment