Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
8c0ca1a0
Commit
8c0ca1a0
authored
Aug 10, 2018
by
Myle Ott
Browse files
Diverse Beam Search
parent
ba9f32cc
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
178 additions
and
7 deletions
+178
-7
fairseq/options.py
fairseq/options.py
+4
-0
fairseq/search.py
fairseq/search.py
+63
-0
fairseq/sequence_generator.py
fairseq/sequence_generator.py
+4
-0
generate.py
generate.py
+3
-3
interactive.py
interactive.py
+5
-4
tests/test_sequence_generator.py
tests/test_sequence_generator.py
+99
-0
No files found.
fairseq/options.py
View file @
8c0ca1a0
...
@@ -305,6 +305,10 @@ def add_generation_args(parser):
...
@@ -305,6 +305,10 @@ def add_generation_args(parser):
help
=
'sample from top K likely next words instead of all words'
)
help
=
'sample from top K likely next words instead of all words'
)
group
.
add_argument
(
'--sampling-temperature'
,
default
=
1
,
type
=
float
,
metavar
=
'N'
,
group
.
add_argument
(
'--sampling-temperature'
,
default
=
1
,
type
=
float
,
metavar
=
'N'
,
help
=
'temperature for random sampling'
)
help
=
'temperature for random sampling'
)
group
.
add_argument
(
'--diverse-beam-groups'
,
default
=
1
,
type
=
int
,
metavar
=
'N'
,
help
=
'number of groups for Diverse Beam Search'
)
group
.
add_argument
(
'--diverse-beam-strength'
,
default
=
0.5
,
type
=
float
,
metavar
=
'N'
,
help
=
'strength of diversity penalty for Diverse Beam Search'
)
group
.
add_argument
(
'--print-alignment'
,
action
=
'store_true'
,
group
.
add_argument
(
'--print-alignment'
,
action
=
'store_true'
,
help
=
'if set, uses attention feedback to compute and print alignment to source tokens'
)
help
=
'if set, uses attention feedback to compute and print alignment to source tokens'
)
group
.
add_argument
(
'--model-overrides'
,
default
=
"{}"
,
type
=
str
,
metavar
=
'DICT'
,
group
.
add_argument
(
'--model-overrides'
,
default
=
"{}"
,
type
=
str
,
metavar
=
'DICT'
,
...
...
fairseq/search.py
View file @
8c0ca1a0
...
@@ -80,6 +80,69 @@ class BeamSearch(Search):
...
@@ -80,6 +80,69 @@ class BeamSearch(Search):
return
self
.
scores_buf
,
self
.
indices_buf
,
self
.
beams_buf
return
self
.
scores_buf
,
self
.
indices_buf
,
self
.
beams_buf
class DiverseBeamSearch(Search):
    """Diverse Beam Search.

    See "Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models" (Vijayakumar et al., 2016) for details.

    We only implement the Hamming Diversity penalty here, which performed best
    in the original paper.
    """

    def __init__(self, tgt_dict, num_groups, diversity_strength):
        """
        Args:
            tgt_dict: target dictionary (forwarded to the Search base class)
            num_groups (int): number of groups G; the beam size must be
                divisible by G
            diversity_strength (float): weight of the Hamming diversity
                penalty (non-negative; larger means more diverse groups)
        """
        super().__init__(tgt_dict)
        self.num_groups = num_groups
        # stored negated so the penalty can be applied with a single add below
        self.diversity_strength = -diversity_strength
        self.diversity_buf = None
        # each group runs vanilla beam search over its own sub-beams
        self.beam = BeamSearch(tgt_dict)

    def step(self, step, lprobs, scores):
        """Take a single diverse-beam-search step.

        Args:
            step (int): current decoding step (0 for the first step)
            lprobs: (bsz x beam_size x vocab_size) log-probabilities at this step
            scores: (bsz x beam_size x step) cumulative scores so far

        Returns:
            (scores_buf, indices_buf, beams_buf), each of shape
            (bsz x beam_size), with results from all groups interleaved.

        Raises:
            ValueError: if beam_size is not divisible by num_groups.
        """
        super()._init_buffers(lprobs)
        bsz, beam_size, vocab_size = lprobs.size()
        if beam_size % self.num_groups != 0:
            raise ValueError(
                'DiverseBeamSearch requires --beam to be divisible by the number of groups'
            )
        group_size = beam_size // self.num_groups

        # initialize diversity penalty; the buffer is allocated lazily once,
        # then zeroed at every step (token counts are per-step, not cumulative)
        if self.diversity_buf is None:
            self.diversity_buf = lprobs.new()
        torch.zeros(lprobs[:, 0, :].size(), out=self.diversity_buf)

        scores_G, indices_G, beams_G = [], [], []
        for g in range(self.num_groups):
            # group g owns beams g, g+G, g+2G, ... (strided view)
            lprobs_g = lprobs[:, g::self.num_groups, :]
            scores_g = scores[:, g::self.num_groups, :] if step > 0 else None

            # apply diversity penalty: down-weight tokens already chosen by
            # earlier groups at this step (first group is unpenalized)
            if g > 0:
                # NOTE: the positional overload torch.add(input, value, other)
                # is deprecated and removed in recent PyTorch; use alpha=
                lprobs_g = torch.add(
                    lprobs_g,
                    self.diversity_buf.unsqueeze(1),
                    alpha=self.diversity_strength,
                )
            else:
                lprobs_g = lprobs_g.contiguous()

            scores_buf, indices_buf, beams_buf = self.beam.step(step, lprobs_g, scores_g)
            # map the group-local beam index back to the global beam index
            beams_buf.mul_(self.num_groups).add_(g)

            scores_G.append(scores_buf.clone())
            indices_G.append(indices_buf.clone())
            beams_G.append(beams_buf.clone())

            # update diversity penalty: count how often each token was
            # selected by this group so later groups avoid it
            self.diversity_buf.scatter_add_(
                1,
                indices_buf,
                self.diversity_buf.new_ones(indices_buf.size())
            )

        # interleave results from different groups so beam i belongs to
        # group (i % num_groups), matching the strided slicing above
        self.scores_buf = torch.stack(scores_G, dim=2, out=self.scores_buf).view(bsz, -1)
        self.indices_buf = torch.stack(indices_G, dim=2, out=self.indices_buf).view(bsz, -1)
        self.beams_buf = torch.stack(beams_G, dim=2, out=self.beams_buf).view(bsz, -1)
        return self.scores_buf, self.indices_buf, self.beams_buf
class
Sampling
(
Search
):
class
Sampling
(
Search
):
def
__init__
(
self
,
tgt_dict
,
sampling_topk
=-
1
,
sampling_temperature
=
1.
):
def
__init__
(
self
,
tgt_dict
,
sampling_topk
=-
1
,
sampling_temperature
=
1.
):
...
...
fairseq/sequence_generator.py
View file @
8c0ca1a0
...
@@ -18,6 +18,7 @@ class SequenceGenerator(object):
...
@@ -18,6 +18,7 @@ class SequenceGenerator(object):
self
,
models
,
tgt_dict
,
beam_size
=
1
,
minlen
=
1
,
maxlen
=
None
,
stop_early
=
True
,
self
,
models
,
tgt_dict
,
beam_size
=
1
,
minlen
=
1
,
maxlen
=
None
,
stop_early
=
True
,
normalize_scores
=
True
,
len_penalty
=
1
,
unk_penalty
=
0
,
retain_dropout
=
False
,
normalize_scores
=
True
,
len_penalty
=
1
,
unk_penalty
=
0
,
retain_dropout
=
False
,
sampling
=
False
,
sampling_topk
=-
1
,
sampling_temperature
=
1
,
sampling
=
False
,
sampling_topk
=-
1
,
sampling_temperature
=
1
,
diverse_beam_groups
=-
1
,
diverse_beam_strength
=
0.5
,
):
):
"""Generates translations of a given source sentence.
"""Generates translations of a given source sentence.
Args:
Args:
...
@@ -48,6 +49,8 @@ class SequenceGenerator(object):
...
@@ -48,6 +49,8 @@ class SequenceGenerator(object):
if
sampling
:
if
sampling
:
self
.
search
=
search
.
Sampling
(
tgt_dict
,
sampling_topk
,
sampling_temperature
)
self
.
search
=
search
.
Sampling
(
tgt_dict
,
sampling_topk
,
sampling_temperature
)
elif
diverse_beam_groups
>
0
:
self
.
search
=
search
.
DiverseBeamSearch
(
tgt_dict
,
diverse_beam_groups
,
diverse_beam_strength
)
else
:
else
:
self
.
search
=
search
.
BeamSearch
(
tgt_dict
)
self
.
search
=
search
.
BeamSearch
(
tgt_dict
)
...
@@ -402,6 +405,7 @@ class SequenceGenerator(object):
...
@@ -402,6 +405,7 @@ class SequenceGenerator(object):
active_mask
,
k
=
beam_size
,
dim
=
1
,
largest
=
False
,
active_mask
,
k
=
beam_size
,
dim
=
1
,
largest
=
False
,
out
=
(
_ignore
,
active_hypos
)
out
=
(
_ignore
,
active_hypos
)
)
)
active_bbsz_idx
=
buffer
(
'active_bbsz_idx'
)
active_bbsz_idx
=
buffer
(
'active_bbsz_idx'
)
torch
.
gather
(
torch
.
gather
(
cand_bbsz_idx
,
dim
=
1
,
index
=
active_hypos
,
cand_bbsz_idx
,
dim
=
1
,
index
=
active_hypos
,
...
...
generate.py
View file @
8c0ca1a0
...
@@ -71,11 +71,11 @@ def main(args):
...
@@ -71,11 +71,11 @@ def main(args):
translator
=
SequenceScorer
(
models
,
task
.
target_dictionary
)
translator
=
SequenceScorer
(
models
,
task
.
target_dictionary
)
else
:
else
:
translator
=
SequenceGenerator
(
translator
=
SequenceGenerator
(
models
,
task
.
target_dictionary
,
beam_size
=
args
.
beam
,
models
,
task
.
target_dictionary
,
beam_size
=
args
.
beam
,
minlen
=
args
.
min_len
,
stop_early
=
(
not
args
.
no_early_stop
),
normalize_scores
=
(
not
args
.
unnormalized
),
stop_early
=
(
not
args
.
no_early_stop
),
normalize_scores
=
(
not
args
.
unnormalized
),
len_penalty
=
args
.
lenpen
,
unk_penalty
=
args
.
unkpen
,
len_penalty
=
args
.
lenpen
,
unk_penalty
=
args
.
unkpen
,
sampling
=
args
.
sampling
,
sampling_topk
=
args
.
sampling_topk
,
minlen
=
args
.
min_len
,
sampling
=
args
.
sampling
,
sampling_topk
=
args
.
sampling_topk
,
sampling_temperature
=
args
.
sampling_temperature
,
sampling_temperature
=
args
.
sampling_temperature
,
diverse_beam_groups
=
args
.
diverse_beam_groups
,
diverse_beam_strength
=
args
.
diverse_beam_strength
,
)
)
if
use_cuda
:
if
use_cuda
:
...
...
interactive.py
View file @
8c0ca1a0
...
@@ -90,10 +90,11 @@ def main(args):
...
@@ -90,10 +90,11 @@ def main(args):
# Initialize generator
# Initialize generator
translator
=
SequenceGenerator
(
translator
=
SequenceGenerator
(
models
,
tgt_dict
,
beam_size
=
args
.
beam
,
stop_early
=
(
not
args
.
no_early_stop
),
models
,
tgt_dict
,
beam_size
=
args
.
beam
,
minlen
=
args
.
min_len
,
normalize_scores
=
(
not
args
.
unnormalized
),
len_penalty
=
args
.
lenpen
,
stop_early
=
(
not
args
.
no_early_stop
),
normalize_scores
=
(
not
args
.
unnormalized
),
unk_penalty
=
args
.
unkpen
,
sampling
=
args
.
sampling
,
sampling_topk
=
args
.
sampling_topk
,
len_penalty
=
args
.
lenpen
,
unk_penalty
=
args
.
unkpen
,
minlen
=
args
.
min_len
,
sampling_temperature
=
args
.
sampling_temperature
sampling
=
args
.
sampling
,
sampling_topk
=
args
.
sampling_topk
,
sampling_temperature
=
args
.
sampling_temperature
,
diverse_beam_groups
=
args
.
diverse_beam_groups
,
diverse_beam_strength
=
args
.
diverse_beam_strength
,
)
)
if
use_cuda
:
if
use_cuda
:
...
...
tests/test_sequence_generator.py
View file @
8c0ca1a0
...
@@ -210,5 +210,104 @@ class TestSequenceGenerator(unittest.TestCase):
...
@@ -210,5 +210,104 @@ class TestSequenceGenerator(unittest.TestCase):
self
.
assertEqual
(
t1
.
ne
(
t2
).
long
().
sum
(),
0
)
self
.
assertEqual
(
t1
.
ne
(
t2
).
long
().
sum
(),
0
)
class TestDiverseBeamSearch(unittest.TestCase):
    """Tests SequenceGenerator with DiverseBeamSearch on a tiny dummy task.

    Uses a 2-word vocabulary and hand-specified per-step output
    distributions (``args.beam_probs``) so expected hypotheses and scores
    can be computed by hand.
    """

    def setUp(self):
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        # sanity-check the special-symbol indices assumed by the fixtures below
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        self.eos = d.eos()
        self.w1 = 4
        self.w2 = 5

        # construct source data: two identical 2-token sentences (plus eos)
        self.src_tokens = torch.LongTensor([
            [self.w1, self.w2, self.eos],
            [self.w1, self.w2, self.eos],
        ])
        self.src_lengths = torch.LongTensor([2, 2])

        args = argparse.Namespace()
        unk = 0.
        # per-step model output probabilities, one (bsz*beam x vocab) tensor
        # per decoding step; columns are [eos, unk, w1, w2]
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos      w1   w2
                # sentence 1:
                [0.0, unk, 0.9, 0.1],  # beam 1
                [0.0, unk, 0.9, 0.1],  # beam 2
                # sentence 2:
                [0.0, unk, 0.7, 0.3],
                [0.0, unk, 0.7, 0.3],
            ]),
            # step 1:
            torch.FloatTensor([
                # eos      w1   w2
                # sentence 1:
                [0.0, unk, 0.6, 0.4],
                [0.0, unk, 0.6, 0.4],
                # sentence 2:
                [0.25, unk, 0.35, 0.4],
                [0.25, unk, 0.35, 0.4],
            ]),
            # step 2:
            torch.FloatTensor([
                # eos      w1   w2
                # sentence 1:
                [1.0, unk, 0.0, 0.0],
                [1.0, unk, 0.0, 0.0],
                # sentence 2:
                [0.9, unk, 0.1, 0.0],
                [0.9, unk, 0.1, 0.0],
            ]),
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        self.model = task.build_model(args)
        self.tgt_dict = task.target_dictionary

    def test_diverse_beam_search(self):
        # strength 0 makes both groups behave like plain beam search, so the
        # two beams of each sentence are expected to be identical
        generator = SequenceGenerator(
            [self.model], self.tgt_dict, beam_size=2,
            diverse_beam_groups=2, diverse_beam_strength=0.,
        )
        hypos = generator.generate(self.src_tokens, self.src_lengths)
        eos, w1, w2 = self.eos, self.w1, self.w2
        # sentence 1, beam 1
        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
        self.assertHypoScore(hypos[0][0], [0.9, 0.6, 1.0])
        # sentence 1, beam 2
        self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
        self.assertHypoScore(hypos[0][1], [0.9, 0.6, 1.0])
        # sentence 2, beam 1
        self.assertHypoTokens(hypos[1][0], [w1, w2, eos])
        self.assertHypoScore(hypos[1][0], [0.7, 0.4, 0.9])
        # sentence 2, beam 2
        self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
        self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])

    def assertHypoTokens(self, hypo, tokens):
        # hypothesis token sequence must match exactly
        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
        # pos_probs are per-position probabilities; compare in log space
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
        score = pos_scores.sum()
        if normalized:
            # mirror SequenceGenerator's length-penalty normalization
            score /= pos_scores.numel() ** lenpen
        self.assertLess(abs(score - hypo['score']), 1e-6)

    def assertAlmostEqual(self, t1, t2):
        # element-wise float comparison with a fixed tolerance
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        # exact (integer) tensor equality
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)
# Allow running this test module directly: `python tests/test_sequence_generator.py`
if __name__ == '__main__':
    unittest.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment