transformers · commit d4fc1eb4 (Unverified)

feat: Sequential beam search (#26304)

Authored Jan 19, 2024 by Saibo-creator; committed by GitHub on Jan 19, 2024
Parent: 268fc1fd

Showing 3 changed files with 235 additions and 43 deletions (+235 -43)
src/transformers/generation/configuration_utils.py   +2    -1
src/transformers/generation/utils.py                 +187  -42
tests/generation/test_utils.py                       +46   -0
src/transformers/generation/configuration_utils.py  (view file @ d4fc1eb4)

...
@@ -200,7 +200,8 @@ class GenerationConfig(PushToHubMixin):
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
         low_memory (`bool`, *optional*):
-            Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
+            Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
+            Used with beam search and contrastive search.

         > Parameters that define the output variables of `generate`
...
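For reference, the low-memory path is driven entirely by the `low_memory` flag on `generate`. A minimal usage sketch mirroring the test added in this commit (the "gpt2" checkpoint and the prompt text are illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("The quick brown fox", return_tensors="pt")

    # low_memory=True switches beam search to the sequential (sub-batched) path,
    # trading some speed for a smaller peak memory footprint.
    low_mem_ids = model.generate(**inputs, num_beams=5, max_new_tokens=20, early_stopping=True, low_memory=True)

    # The default parallel beam search is expected to produce identical sequences.
    baseline_ids = model.generate(**inputs, num_beams=5, max_new_tokens=20, early_stopping=True, low_memory=False)
    assert low_mem_ids.tolist() == baseline_ids.tolist()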
src/transformers/generation/utils.py  (view file @ d4fc1eb4)
...
@@ -1558,6 +1558,7 @@ class GenerationMixin:
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
...
@@ -1951,8 +1952,7 @@ class GenerationMixin:
                 model_kwargs["past_key_values"] = tuple(new_key_values)

             if sequential:
-                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
-                all_last_hstates, all_hstates, all_logits = [], [], []
+                all_outputs = []
                 for i in range(top_k):
                     # compute the candidate tokens by the language model and collect their hidden_states
                     next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
...
@@ -1963,32 +1963,8 @@ class GenerationMixin:
                         output_hidden_states=True,
                         output_attentions=output_attentions,
                     )
-                    for key in all_outputs:
-                        all_outputs[key].append(outputs[key])
-
-                    if self.config.is_encoder_decoder:
-                        next_hidden = outputs.decoder_hidden_states[-1]
-                        full_hidden_states = outputs.decoder_hidden_states
-                    else:
-                        next_hidden = outputs.hidden_states[-1]
-                        full_hidden_states = outputs.hidden_states
-
-                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
-                    all_hstates.append(full_hidden_states)
-                    all_logits.append(outputs.logits[:, -1, :])
-
-                # stack hidden states
-                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
-                final_full_hstates = [0 for i in range(len(full_hidden_states))]
-                for layer in range(len(full_hidden_states)):
-                    final_full_hstates[layer] = torch.stack(
-                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
-                    )
-                full_hidden_states = tuple(final_full_hstates)
-
-                # stack logits
-                logits = torch.cat(all_logits, dim=0)
+                    all_outputs.append(outputs)
+                outputs = stack_model_outputs(all_outputs)

             else:
                 # compute the candidate tokens by the language model and collect their hidden_states
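The refactor above keeps the sequential (low-memory) candidate scoring of contrastive search but replaces the hand-rolled accumulation of hidden states and logits with a "collect whole outputs, merge once" pattern via the new `stack_model_outputs` helper. A toy sketch of that pattern, with `toy_forward` standing in for the model call (not part of the diff):

    import torch

    def toy_forward(token_column: torch.Tensor) -> dict:
        # stand-in for `self(**next_model_inputs, ...)`: returns a dict of tensors
        return {"logits": torch.randn(token_column.shape[0], 1, 32)}

    top_k_ids = torch.randint(0, 32, (4, 6))  # [batch_size, top_k]
    top_k = top_k_ids.shape[1]

    all_outputs = []
    for i in range(top_k):
        # one candidate column per pass keeps peak activation memory at roughly 1/top_k of a batched pass
        all_outputs.append(toy_forward(top_k_ids[:, i].view(-1, 1)))

    # merge along the batch dimension afterwards (what stack_model_outputs does for full ModelOutput objects)
    merged_logits = torch.cat([o["logits"] for o in all_outputs], dim=0)
    print(merged_logits.shape)  # torch.Size([24, 1, 32])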
...
@@ -2747,6 +2723,7 @@ class GenerationMixin:
         output_scores: Optional[bool] = None,
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[GenerateBeamOutput, torch.LongTensor]:
         r"""
...
@@ -2792,6 +2769,10 @@ class GenerationMixin:
                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
             synced_gpus (`bool`, *optional*, defaults to `False`):
                 Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            sequential (`bool`, defaults to `False`):
+                By default, beam search has `batch_size * num_beams` as effective batch size (see `beam_search()` for
+                more details). This flag will avoid parallelizing the beam search and will instead run beam search
+                sequentially.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                 an encoder-decoder model the kwargs should include `encoder_outputs`.
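The "effective batch size" the docstring refers to is easy to see with plain tensors: parallel beam search runs one forward pass over `batch_size * num_beams` rows, while the sequential path runs several passes over sub-batches of `batch_size` rows. A small illustrative sketch (toy shapes, not part of the diff):

    import torch

    batch_size, num_beams, seq_len = 2, 5, 7
    batch_beam_size = batch_size * num_beams  # effective batch size of parallel beam search

    input_ids = torch.zeros(batch_beam_size, seq_len, dtype=torch.long)

    # the sequential path processes batch_beam_size // batch_size sub-batches per decoding step
    sub_batches = list(torch.split(input_ids, batch_size, dim=0))
    assert len(sub_batches) == batch_beam_size // batch_size
    assert sub_batches[0].shape == (batch_size, seq_len)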
...
@@ -2858,6 +2839,7 @@ class GenerationMixin:
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if max_length is not None:
             warnings.warn(
                 "`max_length` is deprecated in this function, use"
...
@@ -2932,6 +2914,33 @@ class GenerationMixin:
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

+            # if sequential is True, split the input to batches of batch_size and run sequentially
+            if sequential:
+                if any(
+                    model_name in self.__class__.__name__.lower()
+                    for model_name in ["fsmt", "reformer", "bloom", "ctrl", "gpt_bigcode", "transo_xl", "xlnet", "cpm"]
+                ):
+                    raise RuntimeError(
+                        f"Currently generation for {self.__class__.__name__} is not supported "
+                        f"for `low_memory beam_search`. Please open an issue on GitHub if you need this feature."
+                    )
+
+                inputs_per_sub_batches = _split_model_inputs(
+                    model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
+                )
+                outputs_per_sub_batch = [
+                    self(
+                        **inputs_per_sub_batch,
+                        return_dict=True,
+                        output_attentions=output_attentions,
+                        output_hidden_states=output_hidden_states,
+                    )
+                    for inputs_per_sub_batch in inputs_per_sub_batches
+                ]
+
+                outputs = stack_model_outputs(outputs_per_sub_batch)
+
+            else:  # Unchanged original behavior
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
+                outputs = self(
+                    **model_inputs,
+                    return_dict=True,
...
@@ -4656,3 +4665,139 @@ def _ranking_fast(
     contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
+
+
+def _split(data, full_batch_size: int, split_size: int = None):
+    """
+    Takes care of three cases:
+    1. data is a tensor: e.g. last_hidden_state, pooler_output etc. split them on the batch_size dim
+    2. data is a tuple: e.g. hidden_states, attentions etc. Keep the tuple as it is and split each tensor in it and
+       return a list of tuples
+    3. data is a tuple of tuples, e.g. past_key_values. Keep the tuple as it is and split each tuple in it and
+       return a list of tuples of tuples
+    (see documentation of ModelOutput)
+    """
+    if data is None:
+        return [None] * (full_batch_size // split_size)
+    if isinstance(data, torch.Tensor):
+        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
+    elif isinstance(data, tuple):
+        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+        if isinstance(data[0], tuple):
+            return [
+                tuple(tuple(tensor[i : i + split_size] for tensor in inner_tuple) for inner_tuple in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+        else:
+            return [
+                tuple(sub_tensor[i : i + split_size] for sub_tensor in data)
+                for i in range(0, full_batch_size, split_size)
+            ]
+    else:
+        raise ValueError(f"Unexpected attribute type: {type(data)}")
+
+
+def _split_model_inputs(
+    model_input: Union[ModelOutput, Dict], split_size: int, full_batch_size: int
+) -> List[Union[ModelOutput, Dict]]:
+    """
+    Split a ModelOutput object (or its subclasses) or Dict into a list of same-class objects based on a specified split
+    size. The input object is dict when it was prepared for forward pass and ModelOutput when it was returned from
+    previous forward pass.
+    """
+    # Edge case: if model_input is None, return a list of Nones
+    # this happens with Whisper where encoder_outputs is None
+    if model_input is None:
+        return [model_input] * (full_batch_size // split_size)
+    # Infer the class from the object
+    model_output_cls = type(model_input)
+    if (full_batch_size % split_size) != 0:
+        raise ValueError("`full_batch_size` must be divisible by `split_size`")
+
+    if split_size > full_batch_size:
+        raise ValueError("`split_size` must be smaller or equal to `full_batch_size`")
+
+    # Helper function to split tensors or tuples of tensors
+
+    # Find all the dataclass fields (e.g., last_hidden_state, pooler_output etc.) and split them
+    keys = (
+        model_input.__dataclass_fields__.keys() if hasattr(model_input, "__dataclass_fields__") else model_input.keys()
+    )
+    # We only keep keys that are in the model_input
+    keys = [k for k in keys if k in model_input]
+    # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
+    # ModelOutput object.
+    # bool should not be split but replicated for each split
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+
+    # we split the tensors and tuples of tensors
+    data_split_list = [
+        {k: _split(model_input[k], full_batch_size, split_size)[i] for k in non_bool_keys}
+        for i in range(full_batch_size // split_size)
+    ]
+    # bool values are the same and replicated for each split
+    bool_data = {k: model_input[k] for k in bool_keys}
+    # encoder_outputs is a ModelOutput object and should be split by its own
+    if "encoder_outputs" in model_input:
+        encoder_outputs_split = _split_model_inputs(model_input["encoder_outputs"], split_size, full_batch_size)
+        data_split_list = [
+            {**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
+        ]
+
+    # Convert each dictionary in the list to an object of the inferred class
+    split_model_inputs: List[Union[ModelOutput, Dict]] = [
+        model_output_cls(**data_split, **bool_data) for data_split in data_split_list
+    ]
+
+    return split_model_inputs
+
+
+def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
+    """
+    Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
+    specific ModelOutput subclass from the list provided.
+    """
+    if not model_outputs:
+        raise ValueError("Input list is empty.")
+
+    # Infer the class from the first object in the list
+    model_output_cls = type(model_outputs[0])
+
+    # Ensure all objects are of the same type
+    if not all(isinstance(obj, model_output_cls) for obj in model_outputs):
+        raise ValueError("All elements in the list should be of the same type.")
+
+    # Helper function to concat tensors or tuples of tensors
+    def _concat(data):
+        """
+        Reverse of `_split` function above.
+        """
+        if any(data is None for data in data):
+            return None
+        if isinstance(data[0], torch.Tensor):
+            return torch.cat(data, dim=0)
+        elif isinstance(data[0], tuple):
+            # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
+            if isinstance(data[0][0], tuple):
+                return tuple(
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
+                    for i in range(len(data[0]))
+                )
+            else:
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
+        elif isinstance(data[0], (int, float)):
+            # If the elements are integers or floats, return a tensor
+            return torch.tensor(data)
+        else:
+            raise ValueError(f"Unexpected attribute type: {type(data[0])}")
+
+    # Use a dictionary comprehension to gather attributes from all objects and concatenate them
+    concatenated_data = {
+        k: _concat([getattr(model_output, k) for model_output in model_outputs])
+        for k in model_output_cls.__dataclass_fields__.keys()
+    }
+
+    # Return a new object of the inferred class with the concatenated attributes
+    return model_output_cls(**concatenated_data)
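A quick round trip through the two helpers added above makes their contract concrete. Note these are internal functions of `transformers.generation.utils` (the import path follows the file shown in this diff and is not a stable public API), and `CausalLMOutputWithPast` is just a convenient ModelOutput subclass to demonstrate with:

    import torch
    from transformers.generation.utils import _split_model_inputs, stack_model_outputs
    from transformers.modeling_outputs import CausalLMOutputWithPast

    # Split a dict of prepared model inputs (batch of 4) into two sub-batches of 2;
    # boolean entries are replicated rather than split.
    model_inputs = {
        "input_ids": torch.randint(0, 100, (4, 7)),
        "attention_mask": torch.ones(4, 7, dtype=torch.long),
        "use_cache": True,
    }
    sub_inputs = _split_model_inputs(model_inputs, split_size=2, full_batch_size=4)
    assert len(sub_inputs) == 2 and sub_inputs[0]["input_ids"].shape == (2, 7)
    assert sub_inputs[0]["use_cache"] is True

    # Stack two ModelOutput objects back along the batch dimension; None fields stay None.
    out_a = CausalLMOutputWithPast(logits=torch.randn(2, 7, 100))
    out_b = CausalLMOutputWithPast(logits=torch.randn(2, 7, 100))
    stacked = stack_model_outputs([out_a, out_b])
    assert stacked.logits.shape == (4, 7, 100)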
tests/generation/test_utils.py  (view file @ d4fc1eb4)
...
@@ -1539,6 +1539,39 @@ class GenerationTesterMixin:
             )

             self.assertListEqual(low_output.tolist(), high_output.tolist())

+    def test_beam_search_low_memory(self):
+        # Check that choosing 'low_memory' does not change the model output
+        for model_class in self.all_generative_model_classes:
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
+                self.skipTest("Won't fix: old model with different cache format")
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in [
+                    "bloom",
+                    "ctrl",
+                    "gptbigcode",
+                    "transo_xl",
+                    "xlnet",
+                    "cpm",
+                ]
+            ):
+                self.skipTest("May fix in the future: need model-specific fixes")
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=2)
+            # batch_size=1 is ok, but batch_size>1 will cause non-identical output
+            config.use_cache = True
+            config.is_decoder = True
+
+            # test output equality of low versus high memory
+            model = model_class(config).to(torch_device).eval()
+            low_output = model.generate(input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=True)
+
+            high_output = model.generate(
+                input_ids, max_new_tokens=8, num_beams=5, early_stopping=True, low_memory=False
+            )
+            self.assertListEqual(low_output.tolist(), high_output.tolist())
+
     @is_flaky()  # Read NOTE (1) below. If there are API issues, all attempts will fail.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
...
@@ -2766,6 +2799,19 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))

+    def test_beam_search_low_memory(self):
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        model_inputs = tokenizer("I", return_tensors="pt")["input_ids"]
+
+        low_output = model.generate(model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=True)
+
+        high_output = model.generate(
+            model_inputs, max_new_tokens=40, num_beams=5, early_stopping=True, low_memory=False
+        )
+        self.assertListEqual(low_output.tolist(), high_output.tolist())
+
     @slow
     def test_beam_search_example_integration(self):
         # PT-only test: TF doesn't have a BeamSearchScorer
...