chenpangpang / transformers / Commits / caf5e369

Unverified commit caf5e369, authored Jul 20, 2023 by Benjamin Badger, committed by GitHub on Jul 20, 2023
Contrastive Search peak memory reduction (#24120)
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Parent: aa1b09c5
Showing 3 changed files with 147 additions and 31 deletions (+147, -31):
- src/transformers/generation/configuration_utils.py (+4, -0)
- src/transformers/generation/utils.py (+100, -31)
- tests/generation/test_utils.py (+43, -0)
src/transformers/generation/configuration_utils.py
```diff
@@ -189,6 +189,9 @@ class GenerationConfig(PushToHubMixin):
             The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`.
             Higher guidance scale encourages the model to generate samples that are more closely linked to the input
             prompt, usually at the expense of poorer quality.
+        low_memory (`bool`, *optional*):
+            Switch to sequential topk for contrastive search to reduce peak memory. Used with contrastive search.

         > Parameters that define the output variables of `generate`
```
```diff
@@ -270,6 +273,7 @@ class GenerationConfig(PushToHubMixin):
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
         self.sequence_bias = kwargs.pop("sequence_bias", None)
         self.guidance_scale = kwargs.pop("guidance_scale", None)
+        self.low_memory = kwargs.pop("low_memory", None)

         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
```
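The new flag is user-facing through `generate`. A minimal usage sketch (the `gpt2` checkpoint, prompt, and `max_length` below are illustrative, not part of the commit; the `generate` kwargs mirror the new test in this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")

# top_k + penalty_alpha enable contrastive search; low_memory=True (new in this
# commit) evaluates the top-k candidates sequentially instead of in one batched
# forward pass, lowering peak memory at some cost in speed.
output = model.generate(**inputs, top_k=4, penalty_alpha=0.6, low_memory=True, max_length=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```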
src/transformers/generation/utils.py
```diff
@@ -1569,6 +1569,7 @@ class GenerationMixin:
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
+                sequential=generation_config.low_memory,
                 **model_kwargs,
             )
```
```diff
@@ -1832,6 +1833,7 @@ class GenerationMixin:
         return_dict_in_generate: Optional[bool] = None,
         synced_gpus: bool = False,
         streamer: Optional["BaseStreamer"] = None,
+        sequential: Optional[bool] = None,
         **model_kwargs,
     ) -> Union[ContrastiveSearchOutput, torch.LongTensor]:
         r"""
```
```diff
@@ -1882,6 +1884,8 @@ class GenerationMixin:
             streamer (`BaseStreamer`, *optional*):
                 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            sequential (`bool`, *optional*):
+                Switches topk hidden state computation from parallel to sequential to reduce memory if True.
             model_kwargs:
                 Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                 If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
```
```diff
@@ -1921,6 +1925,7 @@ class GenerationMixin:
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        sequential = sequential if sequential is not None else self.generation_config.low_memory
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
```
```diff
@@ -1986,6 +1991,7 @@ class GenerationMixin:
                     last_hidden_states = outputs.decoder_hidden_states[-1]
                 else:
                     last_hidden_states = outputs.hidden_states[-1]
+
                 # next logit for contrastive search to select top-k candidate tokens
                 logit_for_next_step = outputs.logits[:, -1, :]
```
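The context above is the first-iteration bookkeeping: one forward pass over the prefix yields both the last-layer hidden states (later used for the degeneration penalty) and the next-step logits. A standalone sketch of the decoder-only case (checkpoint and prompt are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("The quick brown", return_tensors="pt").input_ids
with torch.no_grad():
    outputs = model(input_ids, return_dict=True, output_hidden_states=True)

last_hidden_states = outputs.hidden_states[-1]  # [batch, seq_len, hidden], for the degeneration penalty
logit_for_next_step = outputs.logits[:, -1, :]  # [batch, vocab], to recall the top-k candidates
```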
```diff
@@ -1995,11 +2001,11 @@ class GenerationMixin:
                     is_encoder_decoder=self.config.is_encoder_decoder,
                     standardize_cache_format=True,
                 )
-            # Expands model inputs top_k times, for batched forward passes (akin to beam search).
-            _, model_kwargs = self._expand_inputs_for_generation(
-                expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
-            )
+            if not sequential:
+                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
+                _, model_kwargs = self._expand_inputs_for_generation(
+                    expand_size=top_k, is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs
+                )

             past_key_values = model_kwargs.get("past_key_values")
             if past_key_values is None:
```
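The new guard is the heart of the memory saving: only the batched path replicates every model input `top_k` times. A toy sketch of what that expansion costs (shapes only; `_expand_inputs_for_generation` internals are simplified to a single `repeat_interleave`):

```python
import torch

batch_size, top_k, seq_len = 2, 4, 10
input_ids = torch.randint(0, 100, (batch_size, seq_len))

# Batched path: every example is replicated top_k times so all candidate
# continuations can be scored in a single forward pass. Activation memory for
# that pass grows by roughly a factor of top_k.
expanded = input_ids.repeat_interleave(top_k, dim=0)
print(expanded.shape)  # torch.Size([8, 10]) == (batch_size * top_k, seq_len)

# Sequential (low_memory) path: inputs keep their original batch size and the
# top_k candidates are fed through the model one at a time in a Python loop.
```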
```diff
@@ -2019,7 +2025,6 @@ class GenerationMixin:
             # contrastive_search main logic start:
             # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
             # degeneration penalty
-
             logit_for_next_step = logits_processor(input_ids, logit_for_next_step)
             logit_for_next_step = logits_warper(input_ids, logit_for_next_step)
             next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)
```
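Step (1), candidate recall, reduces to a softmax plus a top-k over the processed logits. A minimal sketch (random logits stand in for real model output):

```python
import torch

batch_size, vocab_size, top_k = 2, 50257, 4
logit_for_next_step = torch.randn(batch_size, vocab_size)  # stand-in for processed logits

# candidate tokens recall: the top-k most probable next tokens per example
next_probs = torch.nn.functional.softmax(logit_for_next_step, dim=-1)
top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)
print(top_k_ids.shape)  # torch.Size([2, 4])
```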
```diff
@@ -2049,25 +2054,74 @@ class GenerationMixin:
                     items = []
                     # item is either the key or the value matrix
                     for item in layer:
-                        items.append(item.repeat_interleave(top_k, dim=0))
+                        if sequential:
+                            items.append(item.repeat_interleave(1, dim=0))
+                        else:
+                            items.append(item.repeat_interleave(top_k, dim=0))
                     new_key_values.append(items)
                 model_kwargs["past_key_values"] = new_key_values

-            # compute the candidate tokens by the language model and collects their hidden_states
-            next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
-            outputs = self(
-                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
-            )
-            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+            if sequential:
+                all_outputs = {key: [] for key in outputs}  # defined in first loop iteration
+                all_last_hstates, all_hstates, all_logits = [], [], []
+                for i in range(top_k):
+                    # compute the candidate tokens by the language model and collect their hidden_states
+                    next_model_inputs = self.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)

-            logits = outputs.logits[:, -1, :]
-            # name is different for encoder-decoder and decoder-only models
-            if self.config.is_encoder_decoder:
-                next_hidden = outputs.decoder_hidden_states[-1]
-                full_hidden_states = outputs.decoder_hidden_states
-            else:
-                next_hidden = outputs.hidden_states[-1]
-                full_hidden_states = outputs.hidden_states
+                    outputs = self(
+                        **next_model_inputs,
+                        return_dict=True,
+                        output_hidden_states=True,
+                        output_attentions=output_attentions,
+                    )
+                    for key in all_outputs:
+                        all_outputs[key].append(outputs[key])
+
+                    if self.config.is_encoder_decoder:
+                        next_hidden = outputs.decoder_hidden_states[-1]
+                        full_hidden_states = outputs.decoder_hidden_states
+
+                    else:
+                        next_hidden = outputs.hidden_states[-1]
+                        full_hidden_states = outputs.hidden_states
+
+                    all_last_hstates.append(torch.squeeze(next_hidden, 0))
+                    all_hstates.append(full_hidden_states)
+                    all_logits.append(outputs.logits[:, -1, :])
+
+                # stack hidden states
+                next_hidden = torch.stack([all_last_hstates[i] for i in range(top_k)], dim=0)
+                final_full_hstates = [0 for i in range(len(full_hidden_states))]
+                for layer in range(len(full_hidden_states)):
+                    final_full_hstates[layer] = torch.stack(
+                        [torch.squeeze(all_hstates[i][layer], 0) for i in range(top_k)], dim=0
+                    )
+                full_hidden_states = tuple(final_full_hstates)
+
+                # stack logits
+                logits = torch.cat(all_logits, dim=0)
+
+            else:
+                # compute the candidate tokens by the language model and collect their hidden_states
+                # assembles top_k_ids into batch of size k
+                next_model_inputs = self.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
+                outputs = self(
+                    **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
+                )
+                # name is different for encoder-decoder and decoder-only models
+                if self.config.is_encoder_decoder:
+                    next_hidden = outputs.decoder_hidden_states[-1]
+                    full_hidden_states = outputs.decoder_hidden_states
+                else:
+                    next_hidden = outputs.hidden_states[-1]
+                    full_hidden_states = outputs.hidden_states
+
+                logits = outputs.logits[:, -1, :]

             context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

             # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
```
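The two branches above are intended to be numerically equivalent: the sequential loop reassembles, via `torch.stack`/`torch.cat`, exactly what the batched pass would have produced in one shot. A self-contained toy check of that reassembly (a single `Linear` stands in for the model; not the commit's code):

```python
import torch

torch.manual_seed(0)
top_k, hidden, vocab = 4, 8, 16
lm_head = torch.nn.Linear(hidden, vocab)  # stand-in for the language model
candidates = torch.randn(top_k, hidden)   # one hidden state per candidate token

# batched path: score all top_k candidates in one forward pass
batched_logits = lm_head(candidates)

# sequential (low_memory) path: top_k small passes, re-assembled afterwards,
# mirroring how the new branch concatenates `all_logits`
all_logits = [lm_head(candidates[i : i + 1]) for i in range(top_k)]
sequential_logits = torch.cat(all_logits, dim=0)

assert torch.allclose(batched_logits, sequential_logits, atol=1e-6)
```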
```diff
@@ -2089,17 +2143,32 @@ class GenerationMixin:
                 layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
                 next_decoder_hidden_states += (layer,)

-            # select the past_key_value
-            new_key_values = ()
-            for layer in next_past_key_values:
-                items = ()
-                # item is either the key or the value matrix
-                for item in layer:
-                    item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
-                    item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
-                    items += (item,)
-                new_key_values += (items,)
-            next_past_key_values = new_key_values
+            # generate past_key_values cache of only the selected token
+            if sequential:
+                next_model_input = self.prepare_inputs_for_generation(
+                    top_k_ids[:, selected_idx].view(-1, 1), **model_kwargs
+                )
+
+                selected_outputs = self(
+                    **next_model_input,
+                    return_dict=True,
+                    output_hidden_states=False,
+                    output_attentions=False,
+                )
+                next_past_key_values = selected_outputs["past_key_values"]
+
+            else:
+                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+                new_key_values = ()
+                for layer in next_past_key_values:
+                    items = ()
+                    # item is either the key or the value matrix
+                    for item in layer:
+                        item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
+                        item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
+                        items += (item,)
+                    new_key_values += (items,)
+                next_past_key_values = new_key_values

             logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
```
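The `else:` branch keeps the original cache-slicing trick, which is worth unpacking: rows of the expanded batch are grouped as `batch_size` consecutive blocks of `top_k` candidates, so `split` + `stack` + fancy indexing picks each example's winning candidate. A toy sketch with dummy shapes:

```python
import torch

batch_size, top_k = 2, 4
num_head, seq_len, esz = 3, 5, 6

# a key or value matrix from the batched pass: batch_size * top_k rows,
# grouped as [example 0's top_k candidates, example 1's top_k candidates, ...]
item = torch.randn(batch_size * top_k, num_head, seq_len, esz)
selected_idx = torch.tensor([1, 3])  # index of the winning candidate per example

item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...]    # [B, num_head, seq_len, esz]
print(item.shape)  # torch.Size([2, 3, 5, 6])
```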
tests/generation/test_utils.py
```diff
@@ -1457,6 +1457,49 @@ class GenerationTesterMixin:
         for output in (output_contrastive, output_generate):
             self._check_outputs(output, input_ids, model.config, use_cache=True)

+    def test_contrastive_generate_low_memory(self):
+        # Check that choosing 'low_memory' does not change the model output
+        for model_class in self.all_generative_model_classes:
+            # won't fix: FSMT, Reformer, gptbigcode, and speech2text have a different cache variable type (and format).
+            if any(
+                model_name in model_class.__name__.lower()
+                for model_name in ["fsmt", "reformer", "gptbigcode", "speech2text"]
+            ):
+                return
+
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
+
+            # NOTE: contrastive search only works with cache on at the moment.
+            if not hasattr(config, "use_cache"):
+                return
+
+            config.use_cache = True
+            config.is_decoder = True
+
+            # test output equality of low versus high memory
+            model = model_class(config).to(torch_device).eval()
+
+            low_output = model.generate(
+                input_ids,
+                top_k=4,
+                penalty_alpha=0.6,
+                low_memory=True,
+                max_length=max_length,
+                attention_mask=attention_mask,
+            )
+
+            high_output = model.generate(
+                input_ids,
+                top_k=4,
+                penalty_alpha=0.6,
+                low_memory=False,
+                max_length=max_length,
+                attention_mask=attention_mask,
+            )
+            self.assertListEqual(low_output.tolist(), high_output.tolist())
+        return
+
     @slow  # TODO(Joao): remove this. Some models (e.g. data2vec, xcom, roberta) have an error rate between 1 and 10%.
     def test_assisted_decoding_matches_greedy_search(self):
         # This test ensures that the assisted generation does not introduce output changes over greedy search.
```
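Outside the test harness, the same equality can be checked end to end. A hedged sketch (the test above runs on the library's tiny test models; `gpt2`, the prompt, and `max_length` here are just convenient illustrative choices):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
input_ids = tokenizer("Hello world", return_tensors="pt").input_ids

low = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=True, max_length=32)
high = model.generate(input_ids, top_k=4, penalty_alpha=0.6, low_memory=False, max_length=32)
assert low.tolist() == high.tolist()  # the sequential path must not change outputs
```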