Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c0742b15
Unverified
Commit
c0742b15
authored
Jul 25, 2023
by
Joao Gante
Committed by
GitHub
Jul 25, 2023
Browse files
Generate - add beam indices output in constrained beam search (#25042)
parent
c53a6eae
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
13 deletions
+55
-13
src/transformers/generation/beam_search.py
src/transformers/generation/beam_search.py
+35
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+20
-10
No files found.
src/transformers/generation/beam_search.py
View file @
c0742b15
...
...
@@ -43,7 +43,7 @@ PROCESS_INPUTS_DOCSTRING = r"""
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor
]
`, *optional*):
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token corresponds.
group_index (`int`, *optional*):
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
...
...
@@ -510,6 +510,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
scores_for_all_vocab
:
torch
.
FloatTensor
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
beam_indices
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
]:
r
"""
Args:
...
...
@@ -532,6 +533,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token corresponds.
Return:
`UserDict`: A dictionary composed of the fields as defined above:
...
...
@@ -597,9 +600,16 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint
=
self
.
check_completes_constraints
(
input_ids
[
batch_beam_idx
].
cpu
().
tolist
())
if
completes_constraint
:
if
beam_indices
is
not
None
:
beam_index
=
beam_indices
[
batch_beam_idx
]
beam_index
=
beam_index
+
(
batch_beam_idx
,)
else
:
beam_index
=
None
beam_hyp
.
add
(
input_ids
[
batch_beam_idx
].
clone
(),
next_score
.
item
(),
beam_indices
=
beam_index
,
)
else
:
# add next predicted token since it is not eos_token
...
...
@@ -794,6 +804,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
max_length
:
int
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
beam_indices
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Tuple
[
torch
.
LongTensor
]:
batch_size
=
len
(
self
.
_beam_hyps
)
...
...
@@ -816,7 +827,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint
=
self
.
check_completes_constraints
(
final_tokens
.
cpu
().
tolist
())
if
completes_constraint
:
beam_hyp
.
add
(
final_tokens
,
final_score
)
beam_index
=
beam_indices
[
batch_beam_idx
]
if
beam_indices
is
not
None
else
None
beam_hyp
.
add
(
final_tokens
,
final_score
,
beam_indices
=
beam_index
)
ids_collect
.
append
(
beam_id
)
# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
...
...
@@ -834,6 +846,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
# select the best hypotheses
sent_lengths
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
)
best
=
[]
best_indices
=
[]
best_scores
=
torch
.
zeros
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
# retrieve best hypotheses
...
...
@@ -843,10 +856,15 @@ class ConstrainedBeamSearchScorer(BeamScorer):
best_hyp_tuple
=
sorted_hyps
.
pop
()
best_score
=
best_hyp_tuple
[
0
]
best_hyp
=
best_hyp_tuple
[
1
]
best_index
=
best_hyp_tuple
[
2
]
sent_lengths
[
self
.
num_beam_hyps_to_keep
*
i
+
j
]
=
len
(
best_hyp
)
# append to lists
best
.
append
(
best_hyp
)
# append indices to list
best_indices
.
append
(
best_index
)
best_scores
[
i
*
self
.
num_beam_hyps_to_keep
+
j
]
=
best_score
# prepare for adding eos
...
...
@@ -854,15 +872,28 @@ class ConstrainedBeamSearchScorer(BeamScorer):
sent_max_len
=
min
(
sent_lengths_max
,
max_length
)
if
max_length
is
not
None
else
sent_lengths_max
decoded
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
if
len
(
best_indices
)
>
0
and
best_indices
[
0
]
is
not
None
:
indices
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
else
:
indices
=
None
# shorter batches are padded if needed
if
sent_lengths
.
min
().
item
()
!=
sent_lengths
.
max
().
item
():
if
pad_token_id
is
None
:
raise
ValueError
(
"`pad_token_id` has to be defined"
)
decoded
.
fill_
(
pad_token_id
)
if
indices
is
not
None
:
indices
.
fill_
(
-
1
)
# fill with hypotheses and eos_token_id if the latter fits in
for
i
,
hypo
in
enumerate
(
best
):
for
i
,
(
hypo
,
best_idx
)
in
enumerate
(
zip
(
best
,
best_indices
)
):
decoded
[
i
,
:
sent_lengths
[
i
]]
=
hypo
if
indices
is
not
None
:
indices
[
i
,
:
len
(
best_idx
)]
=
torch
.
tensor
(
best_idx
)
if
sent_lengths
[
i
]
<
sent_max_len
:
# inserting only the first eos_token_id
decoded
[
i
,
sent_lengths
[
i
]]
=
eos_token_id
[
0
]
...
...
@@ -871,6 +902,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
{
"sequences"
:
decoded
,
"sequence_scores"
:
best_scores
,
"beam_indices"
:
indices
,
}
)
...
...
src/transformers/generation/utils.py
View file @
c0742b15
...
...
@@ -4000,8 +4000,21 @@ class GenerationMixin:
else
self
.
generation_config
.
return_dict_in_generate
)
batch_size
=
len
(
constrained_beam_scorer
.
_beam_hyps
)
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
f
"Batch dimension of `input_ids` should be
{
num_beams
*
batch_size
}
, but is
{
batch_beam_size
}
."
)
# init attention / hidden states / scores tuples
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
beam_indices
=
(
tuple
(()
for
_
in
range
(
batch_beam_size
))
if
(
return_dict_in_generate
and
output_scores
)
else
None
)
decoder_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
cross_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
decoder_hidden_states
=
()
if
(
return_dict_in_generate
and
output_hidden_states
)
else
None
...
...
@@ -4013,16 +4026,6 @@ class GenerationMixin:
model_kwargs
[
"encoder_outputs"
].
get
(
"hidden_states"
)
if
output_hidden_states
else
None
)
batch_size
=
len
(
constrained_beam_scorer
.
_beam_hyps
)
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
f
"Batch dimension of `input_ids` should be
{
num_beams
*
batch_size
}
, but is
{
batch_beam_size
}
."
)
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores
=
torch
.
zeros
((
batch_size
,
num_beams
),
dtype
=
torch
.
float
,
device
=
input_ids
.
device
)
...
...
@@ -4107,6 +4110,7 @@ class GenerationMixin:
scores_for_all_vocab
,
pad_token_id
=
pad_token_id
,
eos_token_id
=
eos_token_id
,
beam_indices
=
beam_indices
,
)
beam_scores
=
beam_outputs
[
"next_beam_scores"
]
beam_next_tokens
=
beam_outputs
[
"next_beam_tokens"
]
...
...
@@ -4119,6 +4123,9 @@ class GenerationMixin:
if
model_kwargs
[
"past_key_values"
]
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_reorder_cache
(
model_kwargs
[
"past_key_values"
],
beam_idx
)
if
return_dict_in_generate
and
output_scores
:
beam_indices
=
tuple
((
beam_indices
[
beam_idx
[
i
]]
+
(
beam_idx
[
i
],)
for
i
in
range
(
len
(
beam_indices
))))
# increase cur_len
cur_len
=
cur_len
+
1
...
...
@@ -4136,6 +4143,7 @@ class GenerationMixin:
pad_token_id
=
pad_token_id
,
eos_token_id
=
eos_token_id
,
max_length
=
stopping_criteria
.
max_length
,
beam_indices
=
beam_indices
,
)
if
return_dict_in_generate
:
...
...
@@ -4146,6 +4154,7 @@ class GenerationMixin:
sequences
=
sequence_outputs
[
"sequences"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
scores
=
scores
,
beam_indices
=
sequence_outputs
[
"beam_indices"
],
encoder_attentions
=
encoder_attentions
,
encoder_hidden_states
=
encoder_hidden_states
,
decoder_attentions
=
decoder_attentions
,
...
...
@@ -4157,6 +4166,7 @@ class GenerationMixin:
sequences
=
sequence_outputs
[
"sequences"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
scores
=
scores
,
beam_indices
=
sequence_outputs
[
"beam_indices"
],
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment