Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c0742b15
Unverified
Commit
c0742b15
authored
Jul 25, 2023
by
Joao Gante
Committed by
GitHub
Jul 25, 2023
Browse files
Generate - add beam indices output in constrained beam search (#25042)
parent
c53a6eae
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
13 deletions
+55
-13
src/transformers/generation/beam_search.py
src/transformers/generation/beam_search.py
+35
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+20
-10
No files found.
src/transformers/generation/beam_search.py
View file @
c0742b15
...
...
@@ -43,7 +43,7 @@ PROCESS_INPUTS_DOCSTRING = r"""
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor
]
`, *optional*):
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token corresponds.
group_index (`int`, *optional*):
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
...
...
@@ -510,6 +510,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
scores_for_all_vocab
:
torch
.
FloatTensor
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
beam_indices
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
]:
r
"""
Args:
...
...
@@ -532,6 +533,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token corresponds.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
...
...
@@ -597,9 +600,16 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint
=
self
.
check_completes_constraints
(
input_ids
[
batch_beam_idx
].
cpu
().
tolist
())
if
completes_constraint
:
if
beam_indices
is
not
None
:
beam_index
=
beam_indices
[
batch_beam_idx
]
beam_index
=
beam_index
+
(
batch_beam_idx
,)
else
:
beam_index
=
None
beam_hyp
.
add
(
input_ids
[
batch_beam_idx
].
clone
(),
next_score
.
item
(),
beam_indices
=
beam_index
,
)
else
:
# add next predicted token since it is not eos_token
...
...
@@ -794,6 +804,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
max_length
:
int
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
beam_indices
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Tuple
[
torch
.
LongTensor
]:
batch_size
=
len
(
self
.
_beam_hyps
)
...
...
@@ -816,7 +827,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint
=
self
.
check_completes_constraints
(
final_tokens
.
cpu
().
tolist
())
if
completes_constraint
:
beam_hyp
.
add
(
final_tokens
,
final_score
)
beam_index
=
beam_indices
[
batch_beam_idx
]
if
beam_indices
is
not
None
else
None
beam_hyp
.
add
(
final_tokens
,
final_score
,
beam_indices
=
beam_index
)
ids_collect
.
append
(
beam_id
)
# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
...
...
@@ -834,6 +846,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
# select the best hypotheses
sent_lengths
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
)
best
=
[]
best_indices
=
[]
best_scores
=
torch
.
zeros
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
# retrieve best hypotheses
...
...
@@ -843,10 +856,15 @@ class ConstrainedBeamSearchScorer(BeamScorer):
best_hyp_tuple
=
sorted_hyps
.
pop
()
best_score
=
best_hyp_tuple
[
0
]
best_hyp
=
best_hyp_tuple
[
1
]
best_index
=
best_hyp_tuple
[
2
]
sent_lengths
[
self
.
num_beam_hyps_to_keep
*
i
+
j
]
=
len
(
best_hyp
)
# append to lists
best
.
append
(
best_hyp
)
# append indices to list
best_indices
.
append
(
best_index
)
best_scores
[
i
*
self
.
num_beam_hyps_to_keep
+
j
]
=
best_score
# prepare for adding eos
...
...
@@ -854,15 +872,28 @@ class ConstrainedBeamSearchScorer(BeamScorer):
sent_max_len
=
min
(
sent_lengths_max
,
max_length
)
if
max_length
is
not
None
else
sent_lengths_max
decoded
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
if
len
(
best_indices
)
>
0
and
best_indices
[
0
]
is
not
None
:
indices
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
else
:
indices
=
None
# shorter batches are padded if needed
if
sent_lengths
.
min
().
item
()
!=
sent_lengths
.
max
().
item
():
if
pad_token_id
is
None
:
raise
ValueError
(
"`pad_token_id` has to be defined"
)
decoded
.
fill_
(
pad_token_id
)
if
indices
is
not
None
:
indices
.
fill_
(
-
1
)
# fill with hypotheses and eos_token_id if the latter fits in
for
i
,
hypo
in
enumerate
(
best
):
for
i
,
(
hypo
,
best_idx
)
in
enumerate
(
zip
(
best
,
best_indices
)
):
decoded
[
i
,
:
sent_lengths
[
i
]]
=
hypo
if
indices
is
not
None
:
indices
[
i
,
:
len
(
best_idx
)]
=
torch
.
tensor
(
best_idx
)
if
sent_lengths
[
i
]
<
sent_max_len
:
# inserting only the first eos_token_id
decoded
[
i
,
sent_lengths
[
i
]]
=
eos_token_id
[
0
]
...
...
@@ -871,6 +902,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
{
"sequences"
:
decoded
,
"sequence_scores"
:
best_scores
,
"beam_indices"
:
indices
,
}
)
...
...
src/transformers/generation/utils.py
View file @
c0742b15
...
...
@@ -4000,8 +4000,21 @@ class GenerationMixin:
else
self
.
generation_config
.
return_dict_in_generate
)
batch_size
=
len
(
constrained_beam_scorer
.
_beam_hyps
)
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
f
"Batch dimension of `input_ids` should be
{
num_beams
*
batch_size
}
, but is
{
batch_beam_size
}
."
)
# init attention / hidden states / scores tuples
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
beam_indices
=
(
tuple
(()
for
_
in
range
(
batch_beam_size
))
if
(
return_dict_in_generate
and
output_scores
)
else
None
)
decoder_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
cross_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
decoder_hidden_states
=
()
if
(
return_dict_in_generate
and
output_hidden_states
)
else
None
...
...
@@ -4013,16 +4026,6 @@ class GenerationMixin:
model_kwargs
[
"encoder_outputs"
].
get
(
"hidden_states"
)
if
output_hidden_states
else
None
)
batch_size
=
len
(
constrained_beam_scorer
.
_beam_hyps
)
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
f
"Batch dimension of `input_ids` should be
{
num_beams
*
batch_size
}
, but is
{
batch_beam_size
}
."
)
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores
=
torch
.
zeros
((
batch_size
,
num_beams
),
dtype
=
torch
.
float
,
device
=
input_ids
.
device
)
...
...
@@ -4107,6 +4110,7 @@ class GenerationMixin:
scores_for_all_vocab
,
pad_token_id
=
pad_token_id
,
eos_token_id
=
eos_token_id
,
beam_indices
=
beam_indices
,
)
beam_scores
=
beam_outputs
[
"next_beam_scores"
]
beam_next_tokens
=
beam_outputs
[
"next_beam_tokens"
]
...
...
@@ -4119,6 +4123,9 @@ class GenerationMixin:
if
model_kwargs
[
"past_key_values"
]
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_reorder_cache
(
model_kwargs
[
"past_key_values"
],
beam_idx
)
if
return_dict_in_generate
and
output_scores
:
beam_indices
=
tuple
((
beam_indices
[
beam_idx
[
i
]]
+
(
beam_idx
[
i
],)
for
i
in
range
(
len
(
beam_indices
))))
# increase cur_len
cur_len
=
cur_len
+
1
...
...
@@ -4136,6 +4143,7 @@ class GenerationMixin:
pad_token_id
=
pad_token_id
,
eos_token_id
=
eos_token_id
,
max_length
=
stopping_criteria
.
max_length
,
beam_indices
=
beam_indices
,
)
if
return_dict_in_generate
:
...
...
@@ -4146,6 +4154,7 @@ class GenerationMixin:
sequences
=
sequence_outputs
[
"sequences"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
scores
=
scores
,
beam_indices
=
sequence_outputs
[
"beam_indices"
],
encoder_attentions
=
encoder_attentions
,
encoder_hidden_states
=
encoder_hidden_states
,
decoder_attentions
=
decoder_attentions
,
...
...
@@ -4157,6 +4166,7 @@ class GenerationMixin:
sequences
=
sequence_outputs
[
"sequences"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
scores
=
scores
,
beam_indices
=
sequence_outputs
[
"beam_indices"
],
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment