OpenDAS / Fairseq · Commit bfeb7732

Pass encoder_input to generator, rather than src_tokens/src_lengths.

Authored Sep 08, 2018 by Stephen Roller; committed Sep 25, 2018 by Myle Ott.
Parent: 8bd8ec8f
Showing 3 changed files with 41 additions and 20 deletions:

  fairseq/sequence_generator.py     +27 -11
  interactive.py                     +2  -2
  tests/test_sequence_generator.py  +12  -7
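At its core, the change replaces the positional (src_tokens, src_lengths) pair with a single encoder_input dict: everything in the batch's net_input except prev_output_tokens, so the result can be splatted straight into model.encoder(...). A minimal, dependency-free sketch of that filtering pattern (the helper make_encoder_input and the plain-list stand-ins for tensors are illustrative, not part of the commit):

def make_encoder_input(net_input):
    # Keep every net_input field except the decoder-only one; the result
    # matches the encoder's keyword arguments.
    return {k: v for k, v in net_input.items() if k != 'prev_output_tokens'}

net_input = {
    'src_tokens': [[4, 5, 2], [4, 6, 2]],    # (bsz, srclen) token ids
    'src_lengths': [3, 3],
    'prev_output_tokens': [[2, 4], [2, 6]],  # decoder input, filtered out
}
encoder_input = make_encoder_input(net_input)
assert sorted(encoder_input) == ['src_lengths', 'src_tokens']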
fairseq/sequence_generator.py

@@ -78,13 +78,18 @@ class SequenceGenerator(object):
             if 'net_input' not in s:
                 continue
             input = s['net_input']
-            srclen = input['src_tokens'].size(1)
+            # model.forward normally channels prev_output_tokens into the decoder
+            # separately, but SequenceGenerator directly calls model.encoder
+            encoder_input = {
+                k: v for k, v in input.items()
+                if k != 'prev_output_tokens'
+            }
+            srclen = encoder_input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
             with torch.no_grad():
                 hypos = self.generate(
-                    input['src_tokens'],
-                    input['src_lengths'],
+                    encoder_input,
                     beam_size=beam_size,
                     maxlen=int(maxlen_a * srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,

@@ -97,12 +102,23 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]
 
-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
-        """Generate a batch of translations."""
+    def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """Generate a batch of translations.
+
+        Args:
+            encoder_input: dictionary containing the inputs to
+                model.encoder.forward
+            beam_size: int overriding the beam size. defaults to
+                self.beam_size
+            maxlen: maximum length of the generated sequence
+            prefix_tokens: force decoder to begin with these tokens
+        """
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
 
-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """See generate"""
+        src_tokens = encoder_input['src_tokens']
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

@@ -121,10 +137,10 @@ class SequenceGenerator(object):
                 incremental_states[model] = None
 
             # compute the encoder output for each beam
-            encoder_out = model.encoder(
-                src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
-            )
+            encoder_out = model.encoder(**encoder_input)
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(src_tokens.device)
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
             encoder_outs.append(encoder_out)
 
         # initialize buffers
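A detail of the sequence_generator.py change worth calling out: beam replication moves from before encoding (tiling src_tokens beam_size times) to after it (reorder_encoder_out on the encoder output), so each source sentence is encoded once rather than beam_size times. A toy sketch of the index arithmetic, assuming a random tensor as a stand-in for the encoder output and letting index_select play the role of the model-specific reorder_encoder_out:

import torch

bsz, beam_size, srclen, dim = 2, 2, 3, 4
encoder_out = torch.randn(bsz, srclen, dim)  # stand-in encoder output

# new_order maps each of the bsz * beam_size beam slots back to its
# source sentence: tensor([0, 0, 1, 1]) for bsz=2, beam_size=2.
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)

replicated = encoder_out.index_select(0, new_order)
assert replicated.shape == (bsz * beam_size, srclen, dim)
assert torch.equal(replicated[0], replicated[1])  # both beams of sentence 0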
interactive.py

@@ -145,9 +145,9 @@ def main(args):
             tokens = tokens.cuda()
             lengths = lengths.cuda()
 
+        encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
         translations = translator.generate(
-            tokens,
-            lengths,
+            encoder_input,
             maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
         )
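The maxlen expression that interactive.py keeps passing is a linear budget in the source length, a * srclen + b. A tiny sketch of the formula (the helper name and the sample flag values are assumptions for illustration, not values taken from this diff):

def max_target_len(srclen, max_len_a, max_len_b):
    # Upper bound on generated length: grows with the source, plus slack.
    return int(max_len_a * srclen + max_len_b)

assert max_target_len(25, max_len_a=2, max_len_b=10) == 60
assert max_target_len(25, max_len_a=0, max_len_b=200) == 200  # constant cap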
tests/test_sequence_generator.py

@@ -33,6 +33,10 @@ class TestSequenceGenerator(unittest.TestCase):
             [self.w1, self.w2, self.eos],
         ])
         self.src_lengths = torch.LongTensor([2, 2])
+        self.encoder_input = {
+            'src_tokens': self.src_tokens,
+            'src_lengths': self.src_lengths,
+        }
 
         args = argparse.Namespace()
         unk = 0.

@@ -85,7 +89,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_normalization(self):
         generator = SequenceGenerator([self.model], self.tgt_dict)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])

@@ -104,7 +108,7 @@ class TestSequenceGenerator(unittest.TestCase):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
         generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])

@@ -122,7 +126,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])

@@ -140,7 +144,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])

@@ -157,7 +161,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_maxlen(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])

@@ -174,7 +178,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_no_stop_early(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])

@@ -273,7 +277,8 @@ class TestDiverseBeamSearch(unittest.TestCase):
             [self.model], self.tgt_dict,
             beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
         )
-        hypos = generator.generate(self.src_tokens, self.src_lengths)
+        encoder_input = {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}
+        hypos = generator.generate(encoder_input)
         eos, w1, w2 = self.eos, self.w1, self.w2
 
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
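The test updates follow the same mechanical pattern as the library change: build the encoder_input dict (once in setUp, or inline as in the diverse-beam test) and pass it as the single positional argument. The keys matter because SequenceGenerator ultimately splats the dict into the encoder's forward, so they must match its parameter names; a dependency-free sketch (fake_encoder_forward is a stand-in, not fairseq code):

def fake_encoder_forward(src_tokens, src_lengths):
    # Mimics the signature FairseqEncoder.forward(src_tokens, src_lengths);
    # any key mismatch in encoder_input raises a TypeError here.
    return {'encoder_out': src_tokens, 'src_lengths': src_lengths}

encoder_input = {'src_tokens': [[4, 5, 2]], 'src_lengths': [3]}
out = fake_encoder_forward(**encoder_input)
assert out['src_lengths'] == [3]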