Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
5fe8ea46
Commit
5fe8ea46
authored
Oct 25, 2017
by
Michael Auli
Committed by
Myle Ott
Nov 08, 2017
Browse files
Added --unkpen flag to generate.py, following the logic of the Lua/Torch version
parent
6e4b7e22
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
2 deletions
+10
-2
fairseq/options.py
fairseq/options.py
+2
-0
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+6
-1
generate.py
generate.py
+2
-1
No files found.
fairseq/options.py
View file @
5fe8ea46
...
...
@@ -112,6 +112,8 @@ def add_generation_args(parser):
help
=
'don
\'
t use BeamableMM in attention layers'
)
group
.
add_argument
(
'--lenpen'
,
default
=
1
,
type
=
float
,
help
=
'length penalty: <1.0 favors shorter, >1.0 favors longer sentences'
)
group
.
add_argument
(
'--unkpen'
,
default
=
0
,
type
=
float
,
help
=
'unknown word penalty: <0 produces more unks, >0 produces fewer'
)
group
.
add_argument
(
'--unk-replace-dict'
,
default
=
''
,
type
=
str
,
help
=
'performs unk word replacement'
)
group
.
add_argument
(
'--quiet'
,
action
=
'store_true'
,
...
...
fairseq/sequence_generator.py
View file @
5fe8ea46
...
...
@@ -18,7 +18,8 @@ from fairseq.models import FairseqIncrementalDecoder
class
SequenceGenerator
(
object
):
def
__init__
(
self
,
models
,
beam_size
=
1
,
minlen
=
1
,
maxlen
=
200
,
stop_early
=
True
,
normalize_scores
=
True
,
len_penalty
=
1
):
stop_early
=
True
,
normalize_scores
=
True
,
len_penalty
=
1
,
unk_penalty
=
0
):
"""Generates translations of a given source sentence.
Args:
...
...
@@ -31,8 +32,10 @@ class SequenceGenerator(object):
"""
self
.
models
=
models
self
.
pad
=
models
[
0
].
dst_dict
.
pad
()
self
.
unk
=
models
[
0
].
dst_dict
.
unk
()
self
.
eos
=
models
[
0
].
dst_dict
.
eos
()
assert
all
(
m
.
dst_dict
.
pad
()
==
self
.
pad
for
m
in
self
.
models
[
1
:])
assert
all
(
m
.
dst_dict
.
unk
()
==
self
.
unk
for
m
in
self
.
models
[
1
:])
assert
all
(
m
.
dst_dict
.
eos
()
==
self
.
eos
for
m
in
self
.
models
[
1
:])
self
.
vocab_size
=
len
(
models
[
0
].
dst_dict
)
self
.
beam_size
=
beam_size
...
...
@@ -41,6 +44,7 @@ class SequenceGenerator(object):
self
.
stop_early
=
stop_early
self
.
normalize_scores
=
normalize_scores
self
.
len_penalty
=
len_penalty
self
.
unk_penalty
=
unk_penalty
def
cuda
(
self
):
for
model
in
self
.
models
:
...
...
@@ -230,6 +234,7 @@ class SequenceGenerator(object):
# make probs contain cumulative scores for each hypothesis
probs
.
add_
(
scores
.
view
(
-
1
,
1
))
probs
[:,
self
.
pad
]
=
-
math
.
inf
# never select pad
probs
[:,
self
.
unk
]
-=
self
.
unk_penalty
# apply unk penalty
# Record attention scores
attn
[:,
:,
step
+
1
].
copy_
(
avg_attn_scores
)
...
...
generate.py
View file @
5fe8ea46
...
...
@@ -60,7 +60,8 @@ def main():
# Initialize generator
translator
=
SequenceGenerator
(
models
,
beam_size
=
args
.
beam
,
stop_early
=
(
not
args
.
no_early_stop
),
normalize_scores
=
(
not
args
.
unnormalized
),
len_penalty
=
args
.
lenpen
)
normalize_scores
=
(
not
args
.
unnormalized
),
len_penalty
=
args
.
lenpen
,
unk_penalty
=
args
.
unkpen
)
if
use_cuda
:
translator
.
cuda
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment