Fairseq commit 4db6579a
Authored Jan 07, 2018 by Myle Ott

Move normalization of model output (e.g., via LSM) into model definition

Parent: c21a6e29
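The change replaces per-criterion calls to F.log_softmax/F.softmax with a single get_normalized_probs hook owned by the model. A minimal sketch of the pattern this commit introduces (the Toy* class names and tensor shapes are assumptions for illustration, not fairseq code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDecoder(nn.Module):
    """Stand-in for FairseqDecoder: the decoder owns output normalization."""
    def get_normalized_probs(self, net_output, log_probs):
        # Normalize raw logits over the vocabulary dimension.
        fn = F.log_softmax if log_probs else F.softmax
        return fn(net_output, dim=-1)

class ToyModel(nn.Module):
    """Stand-in for FairseqModel: delegates normalization to its decoder."""
    def __init__(self):
        super().__init__()
        self.decoder = ToyDecoder()

    def get_normalized_probs(self, net_output, log_probs):
        return self.decoder.get_normalized_probs(net_output, log_probs)

# Callers (criterions, generators) no longer normalize by hand:
model = ToyModel()
logits = torch.randn(8, 100)  # (batch, vocab), assumed shape
lprobs = model.get_normalized_probs(logits, log_probs=True)
```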
Changes: showing 6 changed files with 21 additions and 7 deletions (+21 −7)
fairseq/criterions/cross_entropy.py (+3 −3)
fairseq/criterions/label_smoothed_cross_entropy.py (+2 −2)
fairseq/models/fairseq_decoder.py (+10 −0)
fairseq/models/fairseq_incremental_decoder.py (+1 −1)
fairseq/models/fairseq_model.py (+4 −0)
fairseq/sequence_generator.py (+1 −1)
fairseq/criterions/cross_entropy.py

```diff
@@ -26,10 +26,10 @@ class CrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
-        input = net_output.view(-1, net_output.size(-1))
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
-        loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx,
-                               reduce=reduce)
+        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
+                          reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0] if reduce else loss.data,
```
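The switch from F.cross_entropy to F.nll_loss avoids normalizing twice: F.cross_entropy applies log_softmax internally, and the model output is now already log-probabilities. A quick check of that identity (toy shapes, and current-PyTorch defaults rather than the size_average/reduce arguments above):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 100)           # (batch, vocab)
target = torch.randint(0, 100, (8,))

a = F.cross_entropy(logits, target)                    # log_softmax + NLL in one op
b = F.nll_loss(F.log_softmax(logits, dim=-1), target)  # NLL on pre-normalized log-probs
assert torch.allclose(a, b)
```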
fairseq/criterions/label_smoothed_cross_entropy.py

```diff
@@ -62,9 +62,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
-        input = F.log_softmax(net_output.view(-1, net_output.size(-1)), dim=1)
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights, reduce)
+        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0] if reduce else loss.data,
```
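LabelSmoothedNLLLoss already expected log-probabilities as input, so only the producer of lprobs changes here. For reference, a generic label-smoothed NLL over log-probs looks roughly like the following; this is a common formulation for illustration, not fairseq's custom autograd Function:

```python
import torch

def label_smoothed_nll(lprobs, target, eps, ignore_index=None):
    """lprobs: (N, V) log-probabilities; target: (N,) class indices."""
    nll = -lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
    smooth = -lprobs.sum(dim=-1)          # loss against a uniform distribution
    if ignore_index is not None:          # e.g. self.padding_idx above
        pad = target.eq(ignore_index)
        nll = nll.masked_fill(pad, 0.0)
        smooth = smooth.masked_fill(pad, 0.0)
    eps_i = eps / lprobs.size(-1)
    return ((1.0 - eps) * nll + eps_i * smooth).sum()
```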
fairseq/models/fairseq_decoder.py

```diff
@@ -7,6 +7,7 @@
 #
 
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 class FairseqDecoder(nn.Module):
@@ -15,6 +16,15 @@ class FairseqDecoder(nn.Module):
     def __init__(self):
         super().__init__()
 
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        vocab = net_output.size(-1)
+        net_output1 = net_output.view(-1, vocab)
+        if log_probs:
+            return F.log_softmax(net_output1, dim=1).view_as(net_output)
+        else:
+            return F.softmax(net_output1, dim=1).view_as(net_output)
+
     def max_positions(self):
         """Maximum input length supported by the decoder."""
         raise NotImplementedError
```
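The base implementation flattens to (−1, vocab), normalizes along dim=1, and restores the original shape, so it works for both (batch, vocab) and (batch, time, vocab) outputs. A quick sanity check, using the class from the diff above with an assumed toy tensor:

```python
import torch

decoder = FairseqDecoder()        # base class as defined above
logits = torch.randn(2, 5, 10)    # (batch, time, vocab), assumed shape

probs = decoder.get_normalized_probs(logits, log_probs=False)
assert probs.shape == logits.shape
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 5))  # each row is a distribution
```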
fairseq/models/fairseq_incremental_decoder.py

```diff
@@ -37,7 +37,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         with model.decoder.incremental_inference():
             for step in range(maxlen):
                 out, _ = model.decoder(tokens[:, :step], encoder_out)
-                probs = torch.nn.functional.log_softmax(out[:, -1, :])
+                probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
         ```
     """
     class IncrementalInference(object):
```
fairseq/models/fairseq_model.py

```diff
@@ -35,6 +35,10 @@ class FairseqModel(nn.Module):
         decoder_out, _ = self.decoder(input_tokens, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))
 
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.decoder.get_normalized_probs(net_output, log_probs)
+
     def max_encoder_positions(self):
         """Maximum input length supported by the encoder."""
         return self.encoder.max_positions()
```
fairseq/sequence_generator.py

```diff
@@ -328,7 +328,7 @@ class SequenceGenerator(object):
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
             decoder_out, attn = model.decoder(tokens, encoder_out)
-            probs = F.softmax(decoder_out[:, -1, :], dim=1).data
+            probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
             else:
```
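log_probs=False is deliberate here: as the avg_probs accumulator suggests, SequenceGenerator combines ensemble members by an arithmetic mean in probability space. Averaging log-probabilities instead would compute a geometric mean, a different ensemble. A toy illustration (values are assumptions):

```python
import torch

p1 = torch.tensor([0.9, 0.1])            # model 1's distribution
p2 = torch.tensor([0.1, 0.9])            # model 2's distribution

arith = (p1 + p2) / 2                    # tensor([0.5000, 0.5000])
geo = ((p1.log() + p2.log()) / 2).exp()  # tensor([0.3000, 0.3000]) -- not even normalized
```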