OpenDAS / Fairseq
Commit bfeb7732, authored Sep 08, 2018 by Stephen Roller; committed by Myle Ott, Sep 25, 2018
Pass encoder_input to generator, rather than src_tokens/src_lengths.
parent 8bd8ec8f
Showing 3 changed files with 41 additions and 20 deletions:

    fairseq/sequence_generator.py     +27  -11
    interactive.py                     +2   -2
    tests/test_sequence_generator.py  +12   -7
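The user-visible change is the signature of SequenceGenerator.generate: callers that previously passed src_tokens and src_lengths as separate tensors now pass a single dictionary of encoder inputs. A minimal before/after sketch of a call site (tensor names here are illustrative, not taken from a specific caller):

    # Before this commit: separate positional tensors.
    hypos = generator.generate(
        src_tokens,    # LongTensor, shape (batch, src_len)
        src_lengths,   # LongTensor, shape (batch,)
        beam_size=2,
    )

    # After this commit: one dict whose keys match model.encoder's inputs.
    encoder_input = {'src_tokens': src_tokens, 'src_lengths': src_lengths}
    hypos = generator.generate(encoder_input, beam_size=2)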
fairseq/sequence_generator.py
@@ -78,13 +78,18 @@ class SequenceGenerator(object):
             if 'net_input' not in s:
                 continue
             input = s['net_input']
-            srclen = input['src_tokens'].size(1)
+            # model.forward normally channels prev_output_tokens into the decoder
+            # separately, but SequenceGenerator directly calls model.encoder
+            encoder_input = {
+                k: v for k, v in input.items()
+                if k != 'prev_output_tokens'
+            }
+            srclen = encoder_input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
             with torch.no_grad():
                 hypos = self.generate(
-                    input['src_tokens'],
-                    input['src_lengths'],
+                    encoder_input,
                     beam_size=beam_size,
                     maxlen=int(maxlen_a*srclen + maxlen_b),
                     prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
@@ -97,12 +102,23 @@ class SequenceGenerator(object):
                 ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
-        """Generate a batch of translations."""
+    def generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """Generate a batch of translations.
+
+        Args:
+            encoder_input: dictionary containing the inputs to
+                model.encoder.forward
+            beam_size: int overriding the beam size. defaults to
+                self.beam_size
+            maxlen: maximum length of the generated sequence
+            prefix_tokens: force decoder to begin with these tokens
+        """
         with torch.no_grad():
-            return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
+            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)

-    def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
+    def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=None):
+        """See generate"""
+        src_tokens = encoder_input['src_tokens']
         bsz, srclen = src_tokens.size()
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -121,10 +137,10 @@ class SequenceGenerator(object):
                 incremental_states[model] = None

             # compute the encoder output for each beam
-            encoder_out = model.encoder(
-                src_tokens.repeat(1, beam_size).view(-1, srclen),
-                src_lengths.expand(beam_size, src_lengths.numel()).t().contiguous().view(-1),
-            )
+            encoder_out = model.encoder(**encoder_input)
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(src_tokens.device)
+            encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
             encoder_outs.append(encoder_out)

         # initialize buffers
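Note the change in how encoder output is expanded per beam: instead of tiling src_tokens and src_lengths beam_size times before encoding, the encoder now runs once on the original batch and its states are repeated afterwards with reorder_encoder_out. A standalone sketch (not fairseq code) of the index tensor this builds, assuming bsz=2 and beam_size=2:

    import torch

    bsz, beam_size = 2, 2
    # Repeat each sentence index beam_size times.
    new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
    print(new_order)  # tensor([0, 0, 1, 1])
    # reorder_encoder_out selects encoder states along the batch dimension
    # with this index, yielding one copy of the states per beam hypothesis.

A side effect of repeating the encoder output rather than the encoder input is that the source batch is encoded once instead of beam_size times.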
interactive.py
@@ -145,9 +145,9 @@ def main(args):
             tokens = tokens.cuda()
             lengths = lengths.cuda()

+        encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
         translations = translator.generate(
-            tokens,
-            lengths,
+            encoder_input,
             maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
         )
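The dict works as the sole argument because its keys line up with the keyword arguments of model.encoder's forward method, which lets SequenceGenerator splat it directly as model.encoder(**encoder_input). A minimal sketch of that contract, with an assumed encoder signature and illustrative tensors:

    import torch

    # Assumed encoder signature; fairseq encoders at this commit take
    # src_tokens and src_lengths (other architectures may take more keys).
    def encoder_forward(src_tokens, src_lengths):
        ...

    # A batch's net_input also carries prev_output_tokens for the decoder;
    # SequenceGenerator strips it before calling the encoder:
    net_input = {
        'src_tokens': torch.LongTensor([[4, 5, 2]]),   # illustrative token ids
        'src_lengths': torch.LongTensor([3]),
        'prev_output_tokens': torch.LongTensor([[2, 4, 5]]),
    }
    encoder_input = {k: v for k, v in net_input.items() if k != 'prev_output_tokens'}
    encoder_forward(**encoder_input)  # same as encoder_forward(src_tokens=..., src_lengths=...)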
tests/test_sequence_generator.py
@@ -33,6 +33,10 @@ class TestSequenceGenerator(unittest.TestCase):
             [self.w1, self.w2, self.eos],
         ])
         self.src_lengths = torch.LongTensor([2, 2])
+        self.encoder_input = {
+            'src_tokens': self.src_tokens,
+            'src_lengths': self.src_lengths,
+        }

         args = argparse.Namespace()
         unk = 0.
@@ -85,7 +89,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_normalization(self):
         generator = SequenceGenerator([self.model], self.tgt_dict)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -104,7 +108,7 @@ class TestSequenceGenerator(unittest.TestCase):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
         generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -122,7 +126,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -140,7 +144,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
         generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
@@ -157,7 +161,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_maxlen(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -174,7 +178,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_no_stop_early(self):
         generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
-        hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
+        hypos = generator.generate(self.encoder_input, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, eos])
@@ -273,7 +277,8 @@ class TestDiverseBeamSearch(unittest.TestCase):
             [self.model], self.tgt_dict,
             beam_size=2, diverse_beam_groups=2, diverse_beam_strength=0.,
         )
-        hypos = generator.generate(self.src_tokens, self.src_lengths)
+        encoder_input = {'src_tokens': self.src_tokens, 'src_lengths': self.src_lengths}
+        hypos = generator.generate(encoder_input)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
         self.assertHypoTokens(hypos[0][0], [w1, w1, eos])