chenpangpang / transformers, commit caf5e369 (unverified)

Contrastive Search peak memory reduction (#24120)

Authored by Benjamin Badger on Jul 20, 2023; committed by GitHub on Jul 20, 2023.
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Parent: aa1b09c5
Showing 3 changed files with 147 additions and 31 deletions (+147, -31).
src/transformers/generation/configuration_utils.py  (+4, -0)
src/transformers/generation/utils.py  (+100, -31)
tests/generation/test_utils.py  (+43, -0)
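At a high level, the commit adds a `low_memory` generation flag that makes contrastive search evaluate its top-k candidates sequentially instead of in one expanded batch, trading some speed for a lower memory peak. A minimal usage sketch (the checkpoint and prompt are illustrative, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any decoder-only checkpoint works; gpt2 is just an example
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Contrastive search with a smaller memory peak", return_tensors="pt")

# penalty_alpha > 0 together with top_k > 1 triggers contrastive search;
# low_memory=True switches the top-k candidate forward passes to a sequential loop.
output_ids = model.generate(
    **inputs,
    penalty_alpha=0.6,
    top_k=4,
    max_new_tokens=32,
    low_memory=True,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))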
src/transformers/generation/configuration_utils.py
...
...
@@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
Higher guidance scale encourages the model to generate samples that are more closely linked to the input
prompt, usually at the expense of poorer quality.
low_memory (`bool`, *optional*):
Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.
> Parameters that define the output variables of `generate`
...
...
@@ -270,6 +273,7 @@ class GenerationConfig(PushToHubMixin):
        self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
        self.sequence_bias = kwargs.pop("sequence_bias", None)
        self.guidance_scale = kwargs.pop("guidance_scale", None)
        self.low_memory = kwargs.pop("low_memory", None)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
...
...
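On the configuration side, `low_memory` behaves like the other optional generation flags above: it is popped from the keyword arguments and stored on the config, defaulting to `None`. A small sketch of setting it through `GenerationConfig` (hedged; the surrounding call site is up to the user):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    penalty_alpha=0.6,   # contrastive search needs penalty_alpha > 0 and top_k > 1
    top_k=4,
    max_new_tokens=32,
    low_memory=True,     # stored via kwargs.pop("low_memory", None) as shown above
)

# generate() reads generation_config.low_memory and forwards it to
# contrastive_search() as the `sequential` argument (see utils.py below).
print(generation_config.low_memory)  # True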
src/transformers/generation/utils.py
...
...
@@ -1569,6 +1569,7 @@ class GenerationMixin:
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                streamer=streamer,
                sequential=generation_config.low_memory,
                **model_kwargs,
            )
...
...
@@ -1832,6 +1833,7 @@ class GenerationMixin:
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        sequential: Optional[bool] = None,
        **model_kwargs,
    ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
        r"""
...
...
@@ -1882,6 +1884,8 @@ class GenerationMixin:
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
sequential (`bool`, *optional*):
Switches topk hidden state computation from parallel to sequential to reduce memory if True.
model_kwargs:
Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
...
...
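To make the `sequential` tradeoff concrete, here is a toy sketch (illustrative only, not the library implementation): the parallel path runs one forward pass over a batch expanded `top_k` times, while the sequential path runs `top_k` smaller passes and never materializes the expanded batch, which is what lowers the memory peak.

import torch

batch_size, top_k, hidden = 2, 4, 8
candidate_ids = torch.randint(0, 100, (batch_size, top_k))

def fake_forward(token_ids):
    # stand-in for a model forward pass: one hidden vector per input row
    return torch.randn(token_ids.shape[0], hidden)

# parallel: a single pass over batch_size * top_k rows (higher peak activation memory)
parallel_hidden = fake_forward(candidate_ids.view(-1, 1))

# sequential: top_k passes over batch_size rows each (lower peak, more compute steps)
sequential_hidden = torch.cat(
    [fake_forward(candidate_ids[:, i].view(-1, 1)) for i in range(top_k)], dim=0
)

# Same total amount of work and the same output shape; the real implementation also
# re-stacks the per-candidate results so the downstream re-ranking code is unchanged.
assert parallel_hidden.shape == sequential_hidden.shape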
@@ -1921,6 +1925,7 @@ class GenerationMixin:
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        sequential = sequential if sequential is not None else self.generation_config.low_memory
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
...
...
@@ -1986,6 +1991,7 @@ class GenerationMixin:
                    last_hidden_states = outputs.decoder_hidden_states[-1]
                else:
                    last_hidden_states = outputs.hidden_states[-1]

                # next logit for contrastive search to select top-k candidate tokens
                logit_for_next_step = outputs.logits[:, -1, :]
...
...
@@ -1995,7 +2001,7 @@ class GenerationMixin:
                    is_encoder_decoder=self.config.is_encoder_decoder,
                    standardize_cache_format=True,
                )
                if not sequential:
                    # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                    _, model_kwargs = self._expand_inputs_for_generation(
                        expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
...
...
@@ -2019,7 +2025,6 @@ class GenerationMixin:
            # contrastive_search main logic start:
            # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
            # degeneration penalty
            logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
            logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
            next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
...
...
@@ -2049,18 +2054,64 @@ class GenerationMixin:
                items = []
                # item is either the key or the value matrix
                for item in layer:
                    if sequential:
                        items.append(item.repeat_interleave(1, dim=0))
                    else:
                        items.append(item.repeat_interleave(top_k, dim=0))
                new_key_values.append(items)
            model_kwargs["past_key_values"] = new_key_values

            # compute the candidate tokens by the language model and collects their hidden_states
            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

            if sequential:
                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
                all_last_hstates, all_hstates, all_logits = [], [], []
                for i in range(top_k):
                    # compute the candidate tokens by the language model and collect their hidden_states
                    next_model_inputs = self.prepare_inputs_for_generation(
                        top_k_ids[:, i].view(-1, 1), **model_kwargs
                    )

                    outputs = self(
                        **next_model_inputs,
                        return_dict=True,
                        output_hidden_states=True,
                        output_attentions=output_attentions,
                    )
                    next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)

                    for key in all_outputs:
                        all_outputs[key].append(outputs[key])

                    logits = outputs.logits[:, -1, :]
                    if self.config.is_encoder_decoder:
                        next_hidden = outputs.decoder_hidden_states[-1]
                        full_hidden_states = outputs.decoder_hidden_states
                    else:
                        next_hidden = outputs.hidden_states[-1]
                        full_hidden_states = outputs.hidden_states

                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
                    all_hstates.append(full_hidden_states)
                    all_logits.append(outputs.logits[:, -1, :])

                # stack hidden states
                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
                final_full_hstates = [0 for i in range(len(full_hidden_states))]
                for layer in range(len(full_hidden_states)):
                    final_full_hstates[layer] = torch.stack(
                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
                    )
                full_hidden_states = tuple(final_full_hstates)

                # stack logits
                logits = torch.cat(all_logits, dim=0)

            else:
                # compute the candidate tokens by the language model and collect their hidden_states
                # assembles top_k_ids into batch of size k
                next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)

                outputs = self(
                    **next_model_inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    output_attentions=output_attentions,
                )

                # name is different for encoder-decoder and decoder-only models
                if self.config.is_encoder_decoder:
                    next_hidden = outputs.decoder_hidden_states[-1]
...
...
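As a rough illustration of the stacking step in the sequential branch above (toy tensors only, not actual model outputs): each of the `top_k` single-candidate passes produces its own last hidden state and logits, which are then stacked and concatenated back into the layout the batched path would have produced, so the re-ranking code that follows stays unchanged.

import torch

top_k, batch_size, hidden, vocab = 4, 1, 8, 100

# pretend per-candidate results from the top_k sequential forward passes
all_last_hstates = [torch.randn(1, hidden) for _ in range(top_k)]      # squeezed last hidden state per candidate
all_logits = [torch.randn(batch_size, vocab) for _ in range(top_k)]    # last-token logits per candidate

next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)  # (top_k, 1, hidden)
logits = torch.cat(all_logits, dim=0)                                          # (top_k * batch_size, vocab)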
@@ -2068,6 +2119,9 @@ class GenerationMixin:
                else:
                    next_hidden = outputs.hidden_states[-1]
                    full_hidden_states = outputs.hidden_states

                logits = outputs.logits[:, -1, :]

            context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
...
...
@@ -2089,7 +2143,22 @@ class GenerationMixin:
                layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                next_decoder_hidden_states += (layer,)

            # select the past_key_value
            # generate past_key_values cache of only the selected token
            if sequential:
                next_model_input = self.prepare_inputs_for_generation(
                    top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
                )

                selected_outputs = self(
                    **next_model_input,
                    return_dict=True,
                    output_hidden_states=False,
                    output_attentions=False,
                )
                next_past_key_values = selected_outputs["past_key_values"]

            else:
                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)

            new_key_values = ()
            for layer in next_past_key_values:
                items = ()
...
...
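End to end, the effect of the sequential path can be checked by comparing peak memory between the two modes; a hedged sketch assuming a CUDA device and an illustrative checkpoint (absolute numbers depend on the model, prompt length, and top_k):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint, not part of the commit
model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
inputs = tokenizer("Peak memory comparison", return_tensors="pt").to("cuda")

for low_memory in (False, True):
    torch.cuda.reset_peak_memory_stats()
    model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=16, low_memory=low_memory)
    peak_mib = torch.cuda.max_memory_allocated() / 1024**2
    print(f"low_memory={low_memory}: peak {peak_mib:.1f} MiB")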
tests/generation/test_utils.py
...
...
@@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
        for output in (output_contrastive, output_generate):
            self._check_outputs(output, input_ids, model.config, use_cache=True)

    def test_contrastive_generate_low_memory(self):
        # Check that choosing 'low_memory' does not change the model output
        for model_class in self.all_generative_model_classes:
            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
            if any(
                model_name in model_class.__name__.lower()
                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
            ):
                return

            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)

            # NOTE: contrastive search only works with cache on at the moment.
            if not hasattr(config, "use_cache"):
                return
            config.use_cache = True
            config.is_decoder = True

            # test output equality of low versus high memory
            model = model_class(config).to(torch_device).eval()
            low_output = model.generate(
                input_ids,
                top_k=4,
                penalty_alpha=0.6,
                low_memory=True,
                max_length=max_length,
                attention_mask=attention_mask,
            )

            high_output = model.generate(
                input_ids,
                top_k=4,
                penalty_alpha=0.6,
                low_memory=False,
                max_length=max_length,
                attention_mask=attention_mask,
            )
            self.assertListEqual(low_output.tolist(), high_output.tolist())

        return

    @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
    def test_assisted_decoding_matches_greedy_search(self):
        # This test ensures that the assisted generation does not introduce output changes over greedy search.
...
...
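Outside the test harness, the same low-versus-high-memory equivalence can be reproduced against a concrete checkpoint (illustrative only; the test above restricts itself to batch size 1 and cache-enabled models, and the same caveats apply here):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
inputs = tokenizer("Equivalence check", return_tensors="pt")

with torch.no_grad():
    low = model.generate(**inputs, top_k=4, penalty_alpha=0.6, low_memory=True, max_new_tokens=16)
    high = model.generate(**inputs, top_k=4, penalty_alpha=0.6, low_memory=False, max_new_tokens=16)

# low_memory should only change how the candidates are evaluated, not which tokens are chosen
assert low.tolist() == high.tolist()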