OpenDAS / Fairseq · Commit ef43da72

Factor out search logic in SequenceGenerator

Authored Aug 09, 2018 by Myle Ott
Parent: 75e12a27

Showing 4 changed files with 210 additions and 70 deletions (+210 -70)
fairseq/search.py               +165  -0
fairseq/sequence_generator.py    +24  -70
generate.py                       +1  -0
tests/test_binaries.py           +20  -0
fairseq/search.py  (new file, 0 → 100644)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


class Search(object):

    def __init__(self, tgt_dict):
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.scores_buf = None
        self.indices_buf = None
        self.beams_buf = None

    def _init_buffers(self, t):
        if self.scores_buf is None:
            self.scores_buf = t.new()
            self.indices_buf = torch.LongTensor().to(device=t.device)
            self.beams_buf = torch.LongTensor().to(device=t.device)

    def step(self, step, lprobs, scores, beam_size):
        """Take a single search step.

        Args:
            step: the current search step, starting at 0
            lprobs: (bsz x input_beam_size x vocab_size)
                the model's log-probabilities over the vocabulary at the current step
            scores: (bsz x input_beam_size x step)
                the historical model scores of each hypothesis up to this point

        Return: A tuple of (scores, indices, beams) where:
            scores: (bsz x output_beam_size)
                the scores of the chosen elements; output_beam_size can be
                larger than input_beam_size, e.g., we may return
                2*input_beam_size to account for EOS
            indices: (bsz x output_beam_size)
                the indices of the chosen elements
            beams: (bsz x output_beam_size)
                the hypothesis ids of the chosen elements, in the range [0, input_beam_size)
        """
        raise NotImplementedError


class BeamSearch(Search):

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)

    def step(self, step, lprobs, scores):
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()
        else:
            # make probs contain cumulative scores for each hypothesis
            lprobs.add_(scores[:, :, step - 1].unsqueeze(-1))

        torch.topk(
            lprobs.view(bsz, -1),
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
            out=(self.scores_buf, self.indices_buf),
        )
        torch.div(self.indices_buf, vocab_size, out=self.beams_buf)
        self.indices_buf.fmod_(vocab_size)
        return self.scores_buf, self.indices_buf, self.beams_buf


class Sampling(Search):

    def __init__(self, tgt_dict, sampling_topk=-1, sampling_temperature=1.):
        super().__init__(tgt_dict)
        self.sampling_topk = sampling_topk
        self.sampling_temperature = sampling_temperature

    def step(self, step, lprobs, scores):
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            lprobs = lprobs[:, ::beam_size, :].contiguous()

        # we exclude the first two vocab items, one of which is pad
        assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
        lprobs_nopad = lprobs[:, :, 2:]

        # only sample from top-k candidates
        if self.sampling_topk > 0:
            lprobs_nopad, topk_indices = lprobs_nopad.topk(self.sampling_topk)

        # sampling temperature
        if self.sampling_temperature != 1.:
            lprobs_nopad = lprobs_nopad.div_(self.sampling_temperature)

        # sample
        probs_nopad = lprobs_nopad.exp_()
        if step == 0:
            self.indices_buf = torch.multinomial(
                probs_nopad.view(bsz, -1),
                beam_size,
                replacement=True,
                out=self.indices_buf,
            ).view(bsz, beam_size)
        else:
            self.indices_buf = torch.multinomial(
                probs_nopad.view(bsz * beam_size, -1),
                1,
                replacement=True,
                out=self.indices_buf,
            ).view(bsz, beam_size)

        if step == 0:
            # expand to beam size
            probs_nopad = probs_nopad.expand(bsz, beam_size, -1)

        # gather scores
        torch.gather(
            probs_nopad,
            dim=2,
            index=self.indices_buf.unsqueeze(-1),
            out=self.scores_buf,
        )
        self.scores_buf = self.scores_buf.log_().view(bsz, -1)

        # remap indices if using top-k sampling
        if self.sampling_topk > 0:
            self.indices_buf = torch.gather(
                topk_indices.expand(bsz, beam_size, -1),
                dim=2,
                index=self.indices_buf.unsqueeze(-1),
            ).squeeze(2)

        # remap indices since we excluded the first two vocab items
        self.indices_buf.add_(2)

        if step == 0:
            self.beams_buf = self.indices_buf.new_zeros(bsz, beam_size)
        else:
            self.beams_buf = torch.arange(0, beam_size, out=self.beams_buf).repeat(bsz, 1)
            # make scores cumulative
            self.scores_buf.add_(
                torch.gather(
                    scores[:, :, step - 1],
                    dim=1,
                    index=self.beams_buf,
                )
            )

        return self.scores_buf, self.indices_buf, self.beams_buf
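The step() contract documented above can be exercised on its own. Below is a minimal sketch (not part of this commit) that drives BeamSearch.step with dummy tensors; StubDict is a hypothetical stand-in for a fairseq Dictionary, since pad()/unk()/eos() and len() are the only parts of that interface Search touches. It assumes the 2018-era PyTorch semantics this file targets (resizable out= buffers, integer torch.div), which newer releases warn about or reject.

import torch
import torch.nn.functional as F

from fairseq.search import BeamSearch


class StubDict(object):
    """Hypothetical minimal dictionary: only the methods Search uses."""
    def pad(self): return 1
    def unk(self): return 3
    def eos(self): return 2
    def __len__(self): return 10  # vocab size


bsz, beam_size, vocab_size = 2, 4, 10
strategy = BeamSearch(StubDict())

# Per-hypothesis log-probs for one decoding step; at step 0 the strategy
# keeps only the first beam, since all hypotheses are still identical.
lprobs = F.log_softmax(torch.randn(bsz, beam_size, vocab_size), dim=-1)
scores = torch.zeros(bsz, beam_size, 0)  # empty score history at step 0

cand_scores, cand_indices, cand_beams = strategy.step(0, lprobs, scores)
# Each output is (bsz x 2*beam_size): candidate scores, token ids, and the
# id of the hypothesis each candidate extends.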
fairseq/sequence_generator.py
@@ -9,7 +9,7 @@ import math
 
 import torch
 
-from fairseq import utils
+from fairseq import search, utils
 from fairseq.models import FairseqIncrementalDecoder
@@ -43,9 +43,13 @@ class SequenceGenerator(object):
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
-        self.sampling = sampling
-        self.sampling_topk = sampling_topk
-        self.sampling_temperature = sampling_temperature
+
+        assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
+
+        if sampling:
+            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
+        else:
+            self.search = search.BeamSearch(tgt_dict)
 
     def cuda(self):
         for model in self.models:
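With this change, candidate selection becomes a strategy object chosen once in the constructor. A hedged sketch of that dispatch as a standalone helper (build_search is illustrative, not a fairseq function; it simply mirrors the constructor logic in the hunk above):

from fairseq import search

def build_search(tgt_dict, sampling=False, sampling_topk=-1, sampling_temperature=1.):
    # Same guard and dispatch as SequenceGenerator.__init__ above.
    assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
    if sampling:
        return search.Sampling(tgt_dict, sampling_topk, sampling_temperature)
    return search.BeamSearch(tgt_dict)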
@@ -273,19 +277,10 @@ class SequenceGenerator(object):
                     model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
                 encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)
 
-            probs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
-            if step == 0:
-                # at the first step all hypotheses are equally likely, so use
-                # only the first beam
-                probs = probs.unfold(0, 1, beam_size).squeeze(2).contiguous()
-                scores = scores.type_as(probs)
-                scores_buf = scores_buf.type_as(probs)
-            elif not self.sampling:
-                # make probs contain cumulative scores for each hypothesis
-                probs.add_(scores[:, step - 1].view(-1, 1))
-
-            probs[:, self.pad] = -math.inf  # never select pad
-            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty
+            lprobs, avg_attn_scores = self._decode(tokens[:, :step + 1], encoder_outs, incremental_states)
+
+            lprobs[:, self.pad] = -math.inf  # never select pad
+            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty
 
             # Record attention scores
             if avg_attn_scores is not None:
@@ -295,74 +290,33 @@ class SequenceGenerator(object):
                     nonpad_idxs = src_tokens.ne(self.pad)
 
                 attn[:, :, step + 1].copy_(avg_attn_scores)
 
-            cand_scores = buffer('cand_scores', type_of=scores)
-            cand_indices = buffer('cand_indices')
-            cand_beams = buffer('cand_beams')
+            scores = scores.type_as(lprobs)
+            scores_buf = scores_buf.type_as(lprobs)
             eos_bbsz_idx = buffer('eos_bbsz_idx')
             eos_scores = buffer('eos_scores', type_of=scores)
             if step < maxlen:
                 if prefix_tokens is not None and step < prefix_tokens.size(1):
-                    probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
+                    probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
                     cand_scores = torch.gather(
                         probs_slice, dim=1,
                         index=prefix_tokens[:, step].view(-1, 1).data
                     ).expand(-1, cand_size)
                     cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
                     cand_beams.resize_as_(cand_indices).fill_(0)
-                elif self.sampling:
-                    assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
-                    if self.sampling_topk > 0:
-                        values, indices = probs[:, 2:].topk(self.sampling_topk)
-                        exp_probs = values.div_(self.sampling_temperature).exp()
-                        if step == 0:
-                            torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
-                        else:
-                            torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
-                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
-                        torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
-                        cand_indices.add_(2)
-                        cand_beams = torch.zeros_like(cand_indices)
-                    else:
-                        exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
-                        if step == 0:
-                            # we exclude the first two vocab items, one of which is pad
-                            torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
-                        else:
-                            torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
-                        cand_indices.add_(2)
-                        torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
-                    cand_scores.log_()
-                    cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
-                    cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
-                    if step == 0:
-                        cand_beams = torch.zeros(bsz, cand_size).type_as(cand_indices)
-                    else:
-                        cand_beams = torch.arange(0, beam_size).repeat(bsz, 2).type_as(cand_indices)
-                        # make scores cumulative
-                        cand_scores.add_(
-                            torch.gather(
-                                scores[:, step - 1].view(bsz, beam_size), dim=1,
-                                index=cand_beams,
-                            )
-                        )
                 else:
-                    # take the best 2 x beam_size predictions. We'll choose the first
-                    # beam_size of these which don't predict eos to continue with.
-                    torch.topk(
-                        probs.view(bsz, -1),
-                        k=min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
-                        out=(cand_scores, cand_indices),
+                    cand_scores, cand_indices, cand_beams = self.search.step(
+                        step,
+                        lprobs.view(bsz, -1, self.vocab_size),
+                        scores.view(bsz, beam_size, -1)[:, :, :step],
                     )
-                    torch.div(cand_indices, self.vocab_size, out=cand_beams)
-                    cand_indices.fmod_(self.vocab_size)
             else:
+                # make probs contain cumulative scores for each hypothesis
+                lprobs.add_(scores[:, step - 1].unsqueeze(-1))
+
                 # finalize all active hypotheses once we hit maxlen
                 # pick the hypothesis with the highest prob of EOS right now
                 torch.sort(
-                    probs[:, self.eos],
+                    lprobs[:, self.eos],
                     descending=True,
                     out=(eos_scores, eos_bbsz_idx),
                 )
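The new call site fixes the shapes the strategy sees: lprobs is viewed as (bsz, beam_size, vocab_size) and the score history is sliced to the steps taken so far. A shape-only sketch with made-up sizes (StubDict as in the earlier sketch; illustrative, not part of the commit):

import torch
from fairseq.search import BeamSearch


class StubDict(object):
    def pad(self): return 1
    def unk(self): return 3
    def eos(self): return 2
    def __len__(self): return 100


bsz, beam_size, vocab_size, maxlen, step = 2, 5, 100, 20, 3
strategy = BeamSearch(StubDict())

lprobs = torch.zeros(bsz * beam_size, vocab_size)  # decoder output at this step
scores = torch.zeros(bsz * beam_size, maxlen)      # running cumulative scores

cand_scores, cand_indices, cand_beams = strategy.step(
    step,
    lprobs.view(bsz, -1, vocab_size),               # (bsz, beam, vocab)
    scores.view(bsz, beam_size, -1)[:, :, :step],   # only the history so far
)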
@@ -406,7 +360,7 @@ class SequenceGenerator(object):
                 new_bsz = bsz - len(finalized_sents)
 
                 # construct batch_idxs which holds indices of batches to keep for the next pass
-                batch_mask = torch.ones(bsz).type_as(cand_indices)
+                batch_mask = cand_indices.new_ones(bsz)
                 batch_mask[cand_indices.new(finalized_sents)] = 0
                 batch_idxs = batch_mask.nonzero().squeeze(-1)
generate.py
@@ -75,6 +75,7 @@ def main(args):
         stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
         len_penalty=args.lenpen, unk_penalty=args.unkpen,
         sampling=args.sampling, sampling_topk=args.sampling_topk, minlen=args.min_len,
+        sampling_temperature=args.sampling_temperature,
     )
 
     if use_cuda:
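The forwarded value comes from the existing --sampling-temperature option, which the new test below exercises end to end. For illustration only, a hedged sketch of the equivalent programmatic construction (models and tgt_dict are assumed to be loaded as elsewhere in generate.py; all other arguments keep their defaults):

from fairseq.sequence_generator import SequenceGenerator

# Sampling-related arguments spelled out; values chosen to match the test below.
translator = SequenceGenerator(
    models, tgt_dict,
    sampling=True,
    sampling_topk=3,
    sampling_temperature=2.0,
)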
tests/test_binaries.py
@@ -58,6 +58,26 @@ class TestTranslation(unittest.TestCase):
                 train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
                 generate_main(data_dir)
 
+    def test_generation(self):
+        with contextlib.redirect_stdout(StringIO()):
+            with tempfile.TemporaryDirectory('test_sampling') as data_dir:
+                create_dummy_data(data_dir)
+                preprocess_translation_data(data_dir)
+                train_translation_model(data_dir, 'fconv_iwslt_de_en')
+                generate_main(data_dir, [
+                    '--sampling',
+                    '--sampling-temperature', '2',
+                    '--beam', '2',
+                    '--nbest', '2',
+                ])
+                generate_main(data_dir, [
+                    '--sampling',
+                    '--sampling-topk', '3',
+                    '--beam', '2',
+                    '--nbest', '2',
+                ])
+                generate_main(data_dir, ['--prefix-size', '2'])
+
     def test_lstm(self):
         with contextlib.redirect_stdout(StringIO()):
             with tempfile.TemporaryDirectory('test_lstm') as data_dir: