Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
6f6cb4ab
Commit
6f6cb4ab
authored
Jan 01, 2018
by
Myle Ott
Browse files
Add reduce kwarg to criterions
parent
dcbf5e75
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
10 deletions
+14
-10
fairseq/criterions/cross_entropy.py
fairseq/criterions/cross_entropy.py
+4
-3
fairseq/criterions/fairseq_criterion.py
fairseq/criterions/fairseq_criterion.py
+1
-1
fairseq/criterions/label_smoothed_cross_entropy.py
fairseq/criterions/label_smoothed_cross_entropy.py
+9
-6
No files found.
fairseq/criterions/cross_entropy.py
View file @
6f6cb4ab
...
...
@@ -17,7 +17,7 @@ class CrossEntropyCriterion(FairseqCriterion):
def
__init__
(
self
,
args
,
dst_dict
):
super
().
__init__
(
args
,
dst_dict
)
def
forward
(
self
,
model
,
sample
):
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
...
...
@@ -28,10 +28,11 @@ class CrossEntropyCriterion(FairseqCriterion):
net_output
=
model
(
**
sample
[
'net_input'
])
input
=
net_output
.
view
(
-
1
,
net_output
.
size
(
-
1
))
target
=
sample
[
'target'
].
view
(
-
1
)
loss
=
F
.
cross_entropy
(
input
,
target
,
size_average
=
False
,
ignore_index
=
self
.
padding_idx
)
loss
=
F
.
cross_entropy
(
input
,
target
,
size_average
=
False
,
ignore_index
=
self
.
padding_idx
,
reduce
=
reduce
)
sample_size
=
sample
[
'target'
].
size
(
0
)
if
self
.
args
.
sentence_avg
else
sample
[
'ntokens'
]
logging_output
=
{
'loss'
:
loss
.
data
[
0
],
'loss'
:
loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'sample_size'
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
...
...
fairseq/criterions/fairseq_criterion.py
View file @
6f6cb4ab
...
...
@@ -16,7 +16,7 @@ class FairseqCriterion(_Loss):
self
.
args
=
args
self
.
padding_idx
=
dst_dict
.
pad
()
def
forward
(
self
,
model
,
sample
):
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
...
...
fairseq/criterions/label_smoothed_cross_entropy.py
View file @
6f6cb4ab
...
...
@@ -17,7 +17,7 @@ from .fairseq_criterion import FairseqCriterion
class
LabelSmoothedNLLLoss
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
input
,
target
,
eps
,
padding_idx
,
weights
):
def
forward
(
ctx
,
input
,
target
,
eps
,
padding_idx
,
weights
,
reduce
=
True
):
grad_input
=
input
.
new
(
input
.
size
()).
zero_
()
target
=
target
.
view
(
target
.
size
(
0
),
1
)
grad_input
=
grad_input
.
scatter_
(
grad_input
.
dim
()
-
1
,
target
,
eps
-
1
)
...
...
@@ -34,11 +34,14 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
grad_input
=
grad_input
.
add
(
-
eps
/
norm
)
ctx
.
grad_input
=
grad_input
if
reduce
:
return
input
.
new
([
grad_input
.
view
(
-
1
).
dot
(
input
.
view
(
-
1
))])
else
:
return
grad_input
*
input
@
staticmethod
def
backward
(
ctx
,
grad
):
return
Variable
(
ctx
.
grad_input
,
volatile
=
True
)
*
grad
,
None
,
None
,
None
,
None
return
Variable
(
ctx
.
grad_input
,
volatile
=
True
)
*
grad
,
None
,
None
,
None
,
None
,
None
class
LabelSmoothedCrossEntropyCriterion
(
FairseqCriterion
):
...
...
@@ -48,7 +51,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
self
.
eps
=
args
.
label_smoothing
self
.
weights
=
weights
def
forward
(
self
,
model
,
sample
):
def
forward
(
self
,
model
,
sample
,
reduce
=
True
):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
...
...
@@ -59,10 +62,10 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
net_output
=
model
(
**
sample
[
'net_input'
])
input
=
F
.
log_softmax
(
net_output
.
view
(
-
1
,
net_output
.
size
(
-
1
)),
dim
=
1
)
target
=
sample
[
'target'
].
view
(
-
1
)
loss
=
LabelSmoothedNLLLoss
.
apply
(
input
,
target
,
self
.
eps
,
self
.
padding_idx
,
self
.
weights
)
loss
=
LabelSmoothedNLLLoss
.
apply
(
input
,
target
,
self
.
eps
,
self
.
padding_idx
,
self
.
weights
,
reduce
)
sample_size
=
sample
[
'target'
].
size
(
0
)
if
self
.
args
.
sentence_avg
else
sample
[
'ntokens'
]
logging_output
=
{
'loss'
:
loss
.
data
[
0
],
'loss'
:
loss
.
data
[
0
]
if
reduce
else
loss
.
data
,
'sample_size'
:
sample_size
,
}
return
loss
,
sample_size
,
logging_output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment