Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c0742b15
Unverified
Commit
c0742b15
authored
Jul 25, 2023
by
Joao Gante
Committed by
GitHub
Jul 25, 2023
Browse files
Generate - add beam indices output in constrained beam search (#25042)
parent
c53a6eae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
13 deletions
+55
-13
src/transformers/generation/beam_search.py
src/transformers/generation/beam_search.py
+35
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+20
-10
No files found.
src/transformers/generation/beam_search.py
View file @
c0742b15
...
@@ -43,7 +43,7 @@ PROCESS_INPUTS_DOCSTRING = r"""
...
@@ -43,7 +43,7 @@ PROCESS_INPUTS_DOCSTRING = r"""
The id of the *padding* token.
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor
]
`, *optional*):
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token corresponds.
Beam indices indicating to which beam hypothesis each token corresponds.
group_index (`int`, *optional*):
group_index (`int`, *optional*):
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
...
@@ -510,6 +510,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -510,6 +510,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
scores_for_all_vocab
:
torch
.
FloatTensor
,
scores_for_all_vocab
:
torch
.
FloatTensor
,
pad_token_id
:
Optional
[
int
]
=
None
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
beam_indices
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
]:
r
"""
r
"""
Args:
Args:
...
@@ -532,6 +533,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -532,6 +533,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
The id of the *padding* token.
The id of the *padding* token.
eos_token_id (`Union[int, List[int]]`, *optional*):
eos_token_id (`Union[int, List[int]]`, *optional*):
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices indicating to which beam hypothesis each token corresponds.
Return:
Return:
`UserDict`: A dictionary composed of the fields as defined above:
`UserDict`: A dictionary composed of the fields as defined above:
...
@@ -597,9 +600,16 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -597,9 +600,16 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint
=
self
.
check_completes_constraints
(
input_ids
[
batch_beam_idx
].
cpu
().
tolist
())
completes_constraint
=
self
.
check_completes_constraints
(
input_ids
[
batch_beam_idx
].
cpu
().
tolist
())
if
completes_constraint
:
if
completes_constraint
:
if
beam_indices
is
not
None
:
beam_index
=
beam_indices
[
batch_beam_idx
]
beam_index
=
beam_index
+
(
batch_beam_idx
,)
else
:
beam_index
=
None
beam_hyp
.
add
(
beam_hyp
.
add
(
input_ids
[
batch_beam_idx
].
clone
(),
input_ids
[
batch_beam_idx
].
clone
(),
next_score
.
item
(),
next_score
.
item
(),
beam_indices
=
beam_index
,
)
)
else
:
else
:
# add next predicted token since it is not eos_token
# add next predicted token since it is not eos_token
...
@@ -794,6 +804,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -794,6 +804,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
max_length
:
int
,
max_length
:
int
,
pad_token_id
:
Optional
[
int
]
=
None
,
pad_token_id
:
Optional
[
int
]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
eos_token_id
:
Optional
[
Union
[
int
,
List
[
int
]]]
=
None
,
beam_indices
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
Tuple
[
torch
.
LongTensor
]:
)
->
Tuple
[
torch
.
LongTensor
]:
batch_size
=
len
(
self
.
_beam_hyps
)
batch_size
=
len
(
self
.
_beam_hyps
)
...
@@ -816,7 +827,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -816,7 +827,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint
=
self
.
check_completes_constraints
(
final_tokens
.
cpu
().
tolist
())
completes_constraint
=
self
.
check_completes_constraints
(
final_tokens
.
cpu
().
tolist
())
if
completes_constraint
:
if
completes_constraint
:
beam_hyp
.
add
(
final_tokens
,
final_score
)
beam_index
=
beam_indices
[
batch_beam_idx
]
if
beam_indices
is
not
None
else
None
beam_hyp
.
add
(
final_tokens
,
final_score
,
beam_indices
=
beam_index
)
ids_collect
.
append
(
beam_id
)
ids_collect
.
append
(
beam_id
)
# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
...
@@ -834,6 +846,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -834,6 +846,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
# select the best hypotheses
# select the best hypotheses
sent_lengths
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
)
sent_lengths
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
)
best
=
[]
best
=
[]
best_indices
=
[]
best_scores
=
torch
.
zeros
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
best_scores
=
torch
.
zeros
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
device
=
self
.
device
,
dtype
=
torch
.
float32
)
# retrieve best hypotheses
# retrieve best hypotheses
...
@@ -843,10 +856,15 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -843,10 +856,15 @@ class ConstrainedBeamSearchScorer(BeamScorer):
best_hyp_tuple
=
sorted_hyps
.
pop
()
best_hyp_tuple
=
sorted_hyps
.
pop
()
best_score
=
best_hyp_tuple
[
0
]
best_score
=
best_hyp_tuple
[
0
]
best_hyp
=
best_hyp_tuple
[
1
]
best_hyp
=
best_hyp_tuple
[
1
]
best_index
=
best_hyp_tuple
[
2
]
sent_lengths
[
self
.
num_beam_hyps_to_keep
*
i
+
j
]
=
len
(
best_hyp
)
sent_lengths
[
self
.
num_beam_hyps_to_keep
*
i
+
j
]
=
len
(
best_hyp
)
# append to lists
# append to lists
best
.
append
(
best_hyp
)
best
.
append
(
best_hyp
)
# append indices to list
best_indices
.
append
(
best_index
)
best_scores
[
i
*
self
.
num_beam_hyps_to_keep
+
j
]
=
best_score
best_scores
[
i
*
self
.
num_beam_hyps_to_keep
+
j
]
=
best_score
# prepare for adding eos
# prepare for adding eos
...
@@ -854,15 +872,28 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -854,15 +872,28 @@ class ConstrainedBeamSearchScorer(BeamScorer):
sent_max_len
=
min
(
sent_lengths_max
,
max_length
)
if
max_length
is
not
None
else
sent_lengths_max
sent_max_len
=
min
(
sent_lengths_max
,
max_length
)
if
max_length
is
not
None
else
sent_lengths_max
decoded
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
decoded
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
if
len
(
best_indices
)
>
0
and
best_indices
[
0
]
is
not
None
:
indices
:
torch
.
LongTensor
=
input_ids
.
new
(
batch_size
*
self
.
num_beam_hyps_to_keep
,
sent_max_len
)
else
:
indices
=
None
# shorter batches are padded if needed
# shorter batches are padded if needed
if
sent_lengths
.
min
().
item
()
!=
sent_lengths
.
max
().
item
():
if
sent_lengths
.
min
().
item
()
!=
sent_lengths
.
max
().
item
():
if
pad_token_id
is
None
:
if
pad_token_id
is
None
:
raise
ValueError
(
"`pad_token_id` has to be defined"
)
raise
ValueError
(
"`pad_token_id` has to be defined"
)
decoded
.
fill_
(
pad_token_id
)
decoded
.
fill_
(
pad_token_id
)
if
indices
is
not
None
:
indices
.
fill_
(
-
1
)
# fill with hypotheses and eos_token_id if the latter fits in
# fill with hypotheses and eos_token_id if the latter fits in
for
i
,
hypo
in
enumerate
(
best
):
for
i
,
(
hypo
,
best_idx
)
in
enumerate
(
zip
(
best
,
best_indices
)
):
decoded
[
i
,
:
sent_lengths
[
i
]]
=
hypo
decoded
[
i
,
:
sent_lengths
[
i
]]
=
hypo
if
indices
is
not
None
:
indices
[
i
,
:
len
(
best_idx
)]
=
torch
.
tensor
(
best_idx
)
if
sent_lengths
[
i
]
<
sent_max_len
:
if
sent_lengths
[
i
]
<
sent_max_len
:
# inserting only the first eos_token_id
# inserting only the first eos_token_id
decoded
[
i
,
sent_lengths
[
i
]]
=
eos_token_id
[
0
]
decoded
[
i
,
sent_lengths
[
i
]]
=
eos_token_id
[
0
]
...
@@ -871,6 +902,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
...
@@ -871,6 +902,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
{
{
"sequences"
:
decoded
,
"sequences"
:
decoded
,
"sequence_scores"
:
best_scores
,
"sequence_scores"
:
best_scores
,
"beam_indices"
:
indices
,
}
}
)
)
...
...
src/transformers/generation/utils.py
View file @
c0742b15
...
@@ -4000,8 +4000,21 @@ class GenerationMixin:
...
@@ -4000,8 +4000,21 @@ class GenerationMixin:
else
self
.
generation_config
.
return_dict_in_generate
else
self
.
generation_config
.
return_dict_in_generate
)
)
batch_size
=
len
(
constrained_beam_scorer
.
_beam_hyps
)
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
f
"Batch dimension of `input_ids` should be
{
num_beams
*
batch_size
}
, but is
{
batch_beam_size
}
."
)
# init attention / hidden states / scores tuples
# init attention / hidden states / scores tuples
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
beam_indices
=
(
tuple
(()
for
_
in
range
(
batch_beam_size
))
if
(
return_dict_in_generate
and
output_scores
)
else
None
)
decoder_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
decoder_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
cross_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
cross_attentions
=
()
if
(
return_dict_in_generate
and
output_attentions
)
else
None
decoder_hidden_states
=
()
if
(
return_dict_in_generate
and
output_hidden_states
)
else
None
decoder_hidden_states
=
()
if
(
return_dict_in_generate
and
output_hidden_states
)
else
None
...
@@ -4013,16 +4026,6 @@ class GenerationMixin:
...
@@ -4013,16 +4026,6 @@ class GenerationMixin:
model_kwargs
[
"encoder_outputs"
].
get
(
"hidden_states"
)
if
output_hidden_states
else
None
model_kwargs
[
"encoder_outputs"
].
get
(
"hidden_states"
)
if
output_hidden_states
else
None
)
)
batch_size
=
len
(
constrained_beam_scorer
.
_beam_hyps
)
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
f
"Batch dimension of `input_ids` should be
{
num_beams
*
batch_size
}
, but is
{
batch_beam_size
}
."
)
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores
=
torch
.
zeros
((
batch_size
,
num_beams
),
dtype
=
torch
.
float
,
device
=
input_ids
.
device
)
beam_scores
=
torch
.
zeros
((
batch_size
,
num_beams
),
dtype
=
torch
.
float
,
device
=
input_ids
.
device
)
...
@@ -4107,6 +4110,7 @@ class GenerationMixin:
...
@@ -4107,6 +4110,7 @@ class GenerationMixin:
scores_for_all_vocab
,
scores_for_all_vocab
,
pad_token_id
=
pad_token_id
,
pad_token_id
=
pad_token_id
,
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
beam_indices
=
beam_indices
,
)
)
beam_scores
=
beam_outputs
[
"next_beam_scores"
]
beam_scores
=
beam_outputs
[
"next_beam_scores"
]
beam_next_tokens
=
beam_outputs
[
"next_beam_tokens"
]
beam_next_tokens
=
beam_outputs
[
"next_beam_tokens"
]
...
@@ -4119,6 +4123,9 @@ class GenerationMixin:
...
@@ -4119,6 +4123,9 @@ class GenerationMixin:
if
model_kwargs
[
"past_key_values"
]
is
not
None
:
if
model_kwargs
[
"past_key_values"
]
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_reorder_cache
(
model_kwargs
[
"past_key_values"
],
beam_idx
)
model_kwargs
[
"past_key_values"
]
=
self
.
_reorder_cache
(
model_kwargs
[
"past_key_values"
],
beam_idx
)
if
return_dict_in_generate
and
output_scores
:
beam_indices
=
tuple
((
beam_indices
[
beam_idx
[
i
]]
+
(
beam_idx
[
i
],)
for
i
in
range
(
len
(
beam_indices
))))
# increase cur_len
# increase cur_len
cur_len
=
cur_len
+
1
cur_len
=
cur_len
+
1
...
@@ -4136,6 +4143,7 @@ class GenerationMixin:
...
@@ -4136,6 +4143,7 @@ class GenerationMixin:
pad_token_id
=
pad_token_id
,
pad_token_id
=
pad_token_id
,
eos_token_id
=
eos_token_id
,
eos_token_id
=
eos_token_id
,
max_length
=
stopping_criteria
.
max_length
,
max_length
=
stopping_criteria
.
max_length
,
beam_indices
=
beam_indices
,
)
)
if
return_dict_in_generate
:
if
return_dict_in_generate
:
...
@@ -4146,6 +4154,7 @@ class GenerationMixin:
...
@@ -4146,6 +4154,7 @@ class GenerationMixin:
sequences
=
sequence_outputs
[
"sequences"
],
sequences
=
sequence_outputs
[
"sequences"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
scores
=
scores
,
scores
=
scores
,
beam_indices
=
sequence_outputs
[
"beam_indices"
],
encoder_attentions
=
encoder_attentions
,
encoder_attentions
=
encoder_attentions
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_hidden_states
=
encoder_hidden_states
,
decoder_attentions
=
decoder_attentions
,
decoder_attentions
=
decoder_attentions
,
...
@@ -4157,6 +4166,7 @@ class GenerationMixin:
...
@@ -4157,6 +4166,7 @@ class GenerationMixin:
sequences
=
sequence_outputs
[
"sequences"
],
sequences
=
sequence_outputs
[
"sequences"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
sequences_scores
=
sequence_outputs
[
"sequence_scores"
],
scores
=
scores
,
scores
=
scores
,
beam_indices
=
sequence_outputs
[
"beam_indices"
],
attentions
=
decoder_attentions
,
attentions
=
decoder_attentions
,
hidden_states
=
decoder_hidden_states
,
hidden_states
=
decoder_hidden_states
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment