OpenDAS / Fairseq · Commit e73fddf4

Filter padding properly in LabelSmoothedCrossEntropyCriterion (#229)

Authored Mar 04, 2018 by Myle Ott; committed Mar 05, 2018 by Sergey Edunov
Parent: 5f29d123
Showing 3 changed files with 108 additions and 66 deletions (+108 −66):
- fairseq/criterions/label_smoothed_cross_entropy.py (+10 −40)
- tests/test_label_smoothing.py (+84 −17)
- tests/utils.py (+14 −9)
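In short: this commit deletes the hand-written LabelSmoothedNLLLoss autograd Function and has LabelSmoothedCrossEntropyCriterion compute the smoothed loss directly from the model's log-probabilities. The per-token negative log-likelihood is gathered at the target index, the smoothing term is the negative sum of log-probabilities over the vocabulary (weighted by eps divided by the vocabulary size), and both terms are indexed with a non-padding mask, so positions whose target is the padding symbol no longer contribute to the loss. New unit tests cover the padding behaviour, reduce=False, and the label_smoothing=0 case.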
fairseq/criterions/label_smoothed_cross_entropy.py (view file @ e73fddf4)

@@ -7,7 +7,6 @@
 import math
 import torch
-from torch.autograd import Variable
 import torch.nn.functional as F
 
 from fairseq import utils
@@ -15,41 +14,6 @@ from fairseq import utils
 from . import FairseqCriterion, register_criterion
 
 
-class LabelSmoothedNLLLoss(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
-        grad_input = input.new(input.size()).zero_()
-        target = target.view(target.size(0), 1)
-        grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
-
-        norm = grad_input.size(-1)
-        if weights is not None:
-            if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
-                weights = Variable(weights, requires_grad=False)
-            norm = weights.sum()
-            grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
-
-        if padding_idx is not None:
-            norm -= 1 if weights is None else weights[padding_idx]
-            grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
-
-        grad_input = grad_input.add(-eps / norm)
-
-        ctx.grad_input = grad_input
-        if reduce:
-            return grad_input.view(-1).dot(input.view(-1))
-        else:
-            return grad_input * input
-
-    @staticmethod
-    def backward(ctx, grad):
-        grad_input = ctx.grad_input
-        if not isinstance(grad_input, torch.autograd.Variable):
-            grad_input = utils.volatile_variable(grad_input)
-        return grad_input * grad, None, None, None, None, None
-
-
 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@@ -73,10 +37,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """
         net_output = model(**sample['net_input'])
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
-        lprobs = lprobs.view(-1, lprobs.size(-1))
-        target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
-        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
+        target = sample['target'].unsqueeze(-1)
+        non_pad_mask = target.ne(self.padding_idx)
+        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
+        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
+        if reduce:
+            nll_loss = nll_loss.sum()
+            smooth_loss = smooth_loss.sum()
+        eps_i = self.eps / lprobs.size(-1)
+        loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
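For readers skimming the diff: the criterion now masks out padded target positions before summing, instead of routing them through the removed autograd Function. The snippet below is a minimal, self-contained sketch of that computation on a toy batch. It is not part of the commit; it uses the modern PyTorch tensor API rather than the 0.3-era Variable API, and the hard-coded eps, padding_idx, and toy targets are illustrative only. Names and shapes mirror the added lines.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 0.1           # label-smoothing epsilon (illustrative value)
padding_idx = 1     # index of the pad symbol (illustrative value)

# Toy batch: 2 sentences, 3 target positions, vocabulary of 7.
bsz, tgt_len, vocab = 2, 3, 7
lprobs = F.log_softmax(torch.randn(bsz, tgt_len, vocab), dim=-1)
target = torch.tensor([[4, 5, 2],
                       [4, 2, 1]])  # last position of sentence 2 is padding

# Mirrors the added lines in LabelSmoothedCrossEntropyCriterion.forward().
target = target.unsqueeze(-1)             # (bsz, tgt_len, 1)
non_pad_mask = target.ne(padding_idx)     # False exactly at padded positions
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]

# Only the 5 real tokens survive the mask; padding contributes nothing.
assert nll_loss.numel() == smooth_loss.numel() == 5

eps_i = eps / lprobs.size(-1)
loss = (1. - eps) * nll_loss.sum() + eps_i * smooth_loss.sum()
print(float(loss))

Because padded positions are excluded from both the NLL and the smoothing terms, appending pure padding to a batch item leaves the loss unchanged, which is exactly the property the new test_padding test below asserts.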
tests/test_label_smoothing.py (view file @ e73fddf4)

@@ -4,31 +4,98 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 #
 
-import torch
+import argparse
+import copy
 import unittest
 
-from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
-from torch.autograd import Variable, gradcheck
+import torch
+from torch.autograd import Variable
+
+from fairseq import utils
+from fairseq.criterions.cross_entropy import CrossEntropyCriterion
+from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
 
 torch.set_default_tensor_type('torch.DoubleTensor')
 
+import tests.utils as test_utils
+
 
 class TestLabelSmoothing(unittest.TestCase):
 
-    def test_label_smoothing(self):
-        input = Variable(torch.randn(3, 5), requires_grad=True)
-        idx = torch.rand(3) * 4
-        target = Variable(idx.long())
-        criterion = LabelSmoothedNLLLoss()
-        self.assertTrue(gradcheck(
-            lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
-        ))
-        weights = torch.ones(5)
-        weights[2] = 0
-        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
-        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))
+    def setUp(self):
+        # build dictionary
+        self.d = test_utils.dummy_dictionary(3)
+        vocab = len(self.d)
+        self.assertEqual(vocab, 4 + 3)  # 4 special + 3 tokens
+        self.assertEqual(self.d.pad(), 1)
+        self.assertEqual(self.d.eos(), 2)
+        self.assertEqual(self.d.unk(), 3)
+        pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6
+
+        # build dataset
+        self.data = [
+            # the first batch item has padding
+            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
+            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
+        ]
+        self.sample = next(test_utils.dummy_dataloader(self.data))
+
+        # build model
+        self.args = argparse.Namespace()
+        self.args.sentence_avg = False
+        self.args.probs = torch.FloatTensor([
+            #      pad   eos  unk   w1   w2   w3
+            [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
+            [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
+            [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
+        ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
+        self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
+
+    def test_nll_loss(self):
+        self.args.label_smoothing = 0.1
+        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+        self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
+        self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)
+
+    def test_padding(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        loss, _, logging_output = crit(self.model, self.sample)
+
+        def get_one_no_padding(idx):
+            # create a new sample with just a single batch item so that there's
+            # no padding
+            sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
+            args1 = copy.copy(self.args)
+            args1.probs = args1.probs[idx, :, :].unsqueeze(0)
+            model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
+            loss1, _, _ = crit(model1, sample1)
+            return loss1
+
+        loss1 = get_one_no_padding(0)
+        loss2 = get_one_no_padding(1)
+        self.assertAlmostEqual(loss, loss1 + loss2)
+
+    def test_reduction(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        loss, _, logging_output = crit(self.model, self.sample, reduce=True)
+        unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
+        self.assertAlmostEqual(loss, unreduced_loss.sum())
+
+    def test_zero_eps(self):
+        self.args.label_smoothing = 0.0
+        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+        self.assertAlmostEqual(nll_loss, smooth_loss)
+
+    def assertAlmostEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertLess((t1 - t2).abs().max(), 1e-6)
 
 
 if __name__ == '__main__':
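The new tests exercise the criterion end to end through the dummy model in tests/utils.py: test_padding is the regression test for this commit (the batched loss must equal the sum of the two per-sentence losses computed without padding), test_reduction checks that the reduce=False output sums back to the reduced loss, and test_zero_eps checks that the criterion matches plain cross-entropy at label_smoothing=0. They should be runnable with the standard unittest runner from the repository root, e.g. python -m unittest tests.test_label_smoothing (assuming the tests package is importable, as the import of tests.utils suggests).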
tests/utils.py (view file @ e73fddf4)

@@ -92,7 +92,7 @@ class TestEncoder(FairseqEncoder):
 class TestIncrementalDecoder(FairseqIncrementalDecoder):
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
-        assert hasattr(args, 'beam_probs')
+        assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args
 
@@ -116,6 +116,11 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         steps = list(range(tgt_len))
 
         # define output in terms of raw probs
-        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
+        if hasattr(self.args, 'probs'):
+            assert self.args.probs.dim() == 3, \
+                'expected probs to have size bsz*steps*vocab'
+            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
+        else:
+            probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
         for i, step in enumerate(steps):
             # args.beam_probs gives the probability for every vocab element,
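The TestIncrementalDecoder change lets a test supply fixed output distributions through args.probs (a bsz x steps x vocab tensor) instead of beam_probs: each decoding step simply reads the corresponding slice, which makes the criterion outputs deterministic and easy to assert on. A rough stand-alone illustration of that slicing, with made-up shapes and a uniform stand-in tensor (not code from the repository):

import torch

bsz, tgt_len, vocab = 2, 3, 7
# Stand-in for args.probs: one fixed distribution per (sentence, position) pair.
probs = torch.full((bsz, tgt_len, vocab), 1. / vocab)

steps = list(range(tgt_len))
# Same call as the new branch in TestIncrementalDecoder: pick the requested
# time steps along dimension 1.
out = probs.index_select(1, torch.LongTensor(steps))
assert out.shape == (bsz, len(steps), vocab)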