Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
311d2c6c
Commit
311d2c6c
authored
Sep 06, 2018
by
Myle Ott
Browse files
Revert sequence generator changes
parent
0714080b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
15 deletions
+15
-15
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+9
-9
tests/test_sequence_generator.py
tests/test_sequence_generator.py
+6
-6
No files found.
fairseq/sequence_generator.py
View file @
311d2c6c
...
@@ -83,10 +83,11 @@ class SequenceGenerator(object):
...
@@ -83,10 +83,11 @@ class SequenceGenerator(object):
timer
.
start
()
timer
.
start
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
hypos
=
self
.
generate
(
hypos
=
self
.
generate
(
input
[
'src_tokens'
],
input
[
'src_lengths'
],
beam_size
=
beam_size
,
beam_size
=
beam_size
,
maxlen
=
int
(
maxlen_a
*
srclen
+
maxlen_b
),
maxlen
=
int
(
maxlen_a
*
srclen
+
maxlen_b
),
prefix_tokens
=
s
[
'target'
][:,
:
prefix_size
]
if
prefix_size
>
0
else
None
,
prefix_tokens
=
s
[
'target'
][:,
:
prefix_size
]
if
prefix_size
>
0
else
None
,
**
net_input
,
)
)
if
timer
is
not
None
:
if
timer
is
not
None
:
timer
.
stop
(
sum
(
len
(
h
[
0
][
'tokens'
])
for
h
in
hypos
))
timer
.
stop
(
sum
(
len
(
h
[
0
][
'tokens'
])
for
h
in
hypos
))
...
@@ -96,13 +97,12 @@ class SequenceGenerator(object):
...
@@ -96,13 +97,12 @@ class SequenceGenerator(object):
ref
=
utils
.
strip_pad
(
s
[
'target'
].
data
[
i
,
:],
self
.
pad
)
if
s
[
'target'
]
is
not
None
else
None
ref
=
utils
.
strip_pad
(
s
[
'target'
].
data
[
i
,
:],
self
.
pad
)
if
s
[
'target'
]
is
not
None
else
None
yield
id
,
src
,
ref
,
hypos
[
i
]
yield
id
,
src
,
ref
,
hypos
[
i
]
def
generate
(
self
,
beam_size
=
None
,
maxlen
=
None
,
prefix_tokens
=
None
,
**
net_input
):
def
generate
(
self
,
src_tokens
,
src_lengths
,
beam_size
=
None
,
maxlen
=
None
,
prefix_tokens
=
None
):
"""Generate a batch of translations."""
"""Generate a batch of translations."""
with
torch
.
no_grad
():
with
torch
.
no_grad
():
return
self
.
_generate
(
beam_size
,
maxlen
,
prefix_tokens
,
**
net_input
)
return
self
.
_generate
(
src_tokens
,
src_lengths
,
beam_size
,
maxlen
,
prefix_tokens
)
def
_generate
(
self
,
beam_size
=
None
,
maxlen
=
None
,
prefix_tokens
=
None
,
**
net_input
):
def
_generate
(
self
,
src_tokens
,
src_lengths
,
beam_size
=
None
,
maxlen
=
None
,
prefix_tokens
=
None
):
src_tokens
=
net_input
[
'src_tokens'
]
bsz
,
srclen
=
src_tokens
.
size
()
bsz
,
srclen
=
src_tokens
.
size
()
maxlen
=
min
(
maxlen
,
self
.
maxlen
)
if
maxlen
is
not
None
else
self
.
maxlen
maxlen
=
min
(
maxlen
,
self
.
maxlen
)
if
maxlen
is
not
None
else
self
.
maxlen
...
@@ -121,10 +121,10 @@ class SequenceGenerator(object):
...
@@ -121,10 +121,10 @@ class SequenceGenerator(object):
incremental_states
[
model
]
=
None
incremental_states
[
model
]
=
None
# compute the encoder output for each beam
# compute the encoder output for each beam
encoder_out
=
model
.
encoder
(
**
net_input
)
encoder_out
=
model
.
encoder
(
new_order
=
torch
.
arange
(
bsz
).
view
(
-
1
,
1
)
.
repeat
(
1
,
beam_size
).
view
(
-
1
)
src_tokens
.
repeat
(
1
,
beam_size
).
view
(
-
1
,
srclen
),
new_order
=
new_order
.
to
(
net_input
[
'src_tokens'
].
device
)
src_lengths
.
expand
(
beam_size
,
src_lengths
.
numel
()).
t
().
contiguous
().
view
(
-
1
),
encoder_out
=
model
.
encoder
.
reorder_encoder_out
(
encoder_out
,
new_order
)
)
encoder_outs
.
append
(
encoder_out
)
encoder_outs
.
append
(
encoder_out
)
# initialize buffers
# initialize buffers
...
...
tests/test_sequence_generator.py
View file @
311d2c6c
...
@@ -85,7 +85,7 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -85,7 +85,7 @@ class TestSequenceGenerator(unittest.TestCase):
def
test_with_normalization
(
self
):
def
test_with_normalization
(
self
):
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
)
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
)
hypos
=
generator
.
generate
(
src_tokens
=
self
.
src_tokens
,
src_lengths
=
self
.
src_lengths
,
beam_size
=
2
)
hypos
=
generator
.
generate
(
self
.
src_tokens
,
self
.
src_lengths
,
beam_size
=
2
)
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
# sentence 1, beam 1
# sentence 1, beam 1
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
...
@@ -104,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -104,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
# Sentence 1: unchanged from the normalized case
# Sentence 1: unchanged from the normalized case
# Sentence 2: beams swap order
# Sentence 2: beams swap order
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
normalize_scores
=
False
)
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
normalize_scores
=
False
)
hypos
=
generator
.
generate
(
src_tokens
=
self
.
src_tokens
,
src_lengths
=
self
.
src_lengths
,
beam_size
=
2
)
hypos
=
generator
.
generate
(
self
.
src_tokens
,
self
.
src_lengths
,
beam_size
=
2
)
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
# sentence 1, beam 1
# sentence 1, beam 1
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
...
@@ -122,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -122,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
def
test_with_lenpen_favoring_short_hypos
(
self
):
def
test_with_lenpen_favoring_short_hypos
(
self
):
lenpen
=
0.6
lenpen
=
0.6
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
len_penalty
=
lenpen
)
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
len_penalty
=
lenpen
)
hypos
=
generator
.
generate
(
src_tokens
=
self
.
src_tokens
,
src_lengths
=
self
.
src_lengths
,
beam_size
=
2
)
hypos
=
generator
.
generate
(
self
.
src_tokens
,
self
.
src_lengths
,
beam_size
=
2
)
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
# sentence 1, beam 1
# sentence 1, beam 1
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
...
@@ -140,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -140,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
def
test_with_lenpen_favoring_long_hypos
(
self
):
def
test_with_lenpen_favoring_long_hypos
(
self
):
lenpen
=
5.0
lenpen
=
5.0
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
len_penalty
=
lenpen
)
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
len_penalty
=
lenpen
)
hypos
=
generator
.
generate
(
src_tokens
=
self
.
src_tokens
,
src_lengths
=
self
.
src_lengths
,
beam_size
=
2
)
hypos
=
generator
.
generate
(
self
.
src_tokens
,
self
.
src_lengths
,
beam_size
=
2
)
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
# sentence 1, beam 1
# sentence 1, beam 1
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w2
,
w1
,
w2
,
eos
])
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w2
,
w1
,
w2
,
eos
])
...
@@ -157,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -157,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
def
test_maxlen
(
self
):
def
test_maxlen
(
self
):
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
maxlen
=
2
)
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
maxlen
=
2
)
hypos
=
generator
.
generate
(
src_tokens
=
self
.
src_tokens
,
src_lengths
=
self
.
src_lengths
,
beam_size
=
2
)
hypos
=
generator
.
generate
(
self
.
src_tokens
,
self
.
src_lengths
,
beam_size
=
2
)
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
# sentence 1, beam 1
# sentence 1, beam 1
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
...
@@ -174,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -174,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
def
test_no_stop_early
(
self
):
def
test_no_stop_early
(
self
):
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
stop_early
=
False
)
generator
=
SequenceGenerator
([
self
.
model
],
self
.
tgt_dict
,
stop_early
=
False
)
hypos
=
generator
.
generate
(
src_tokens
=
self
.
src_tokens
,
src_lengths
=
self
.
src_lengths
,
beam_size
=
2
)
hypos
=
generator
.
generate
(
self
.
src_tokens
,
self
.
src_lengths
,
beam_size
=
2
)
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
eos
,
w1
,
w2
=
self
.
eos
,
self
.
w1
,
self
.
w2
# sentence 1, beam 1
# sentence 1, beam 1
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
self
.
assertHypoTokens
(
hypos
[
0
][
0
],
[
w1
,
eos
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment