chenpangpang/transformers, commit 849367cc (unverified)
Authored Apr 29, 2023 by Joao Gante, committed by GitHub on Apr 29, 2023
Parent: dfeb5aa6

Generate: prepare assisted generation for release (#23052)
Showing 2 changed files with 63 additions and 76 deletions:

  src/transformers/generation/utils.py  (+61, -74)
  tests/generation/test_utils.py        (+2, -2)
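For context: in assisted generation, a smaller "assistant" model drafts several candidate tokens and the main model checks them in a single forward pass, keeping the drafted tokens up to the first mismatch. Below is a minimal usage sketch, assuming the `assistant_model` argument to `generate()` that this feature exposes; the checkpoint names are illustrative placeholders and can be swapped for any pair of compatible models that share a tokenizer.

```python
# Minimal usage sketch for assisted generation (checkpoint names are placeholders).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b-deduped")
# A much smaller model from the same family drafts candidate tokens.
assistant = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")

inputs = tokenizer("It might be possible to", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```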
src/transformers/generation/utils.py
@@ -4177,7 +4177,6 @@ class GenerationMixin:
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
         ```"""
-        # NOTE: the code here is copy/paste from greedy search/sample, except when clearly stated in the comments
         # Assistant: initialize assistant-related variables
         if not hasattr(assistant_model, "max_assistant_tokens"):
             assistant_model.max_assistant_tokens = 5  # this value, which will be updated, persists across calls
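The `max_assistant_tokens` attribute initialized here is the number of tokens the assistant drafts per iteration. It is adjusted by the heuristic in step 6 further down in this diff, and because it is stored on the assistant model it carries over to later `generate()` calls. A minimal sketch of that update rule, written as a hypothetical standalone helper:

```python
# Sketch of the adaptive drafting-window heuristic used in this diff (hypothetical helper name).
def update_max_assistant_tokens(current: float, n_matches: int) -> float:
    """If every drafted token was accepted, speculate more next time; otherwise back off."""
    if n_matches == int(current):
        return current + 2.0           # all candidates matched: grow the window
    return max(1.0, current - 1.0)     # at least one mismatch: shrink, but never below 1

window = 5.0
for matches in (5, 7, 3, 0):           # accepted-token counts over successive iterations
    window = update_max_assistant_tokens(window, matches)
print(window)  # 5.0 -> 7.0 -> 9.0 -> 8.0 -> 7.0, so this prints 7.0
```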
@@ -4248,20 +4247,20 @@ class GenerationMixin:
                     prev_seq_len = model_kwargs["assistant_past_key_values"][0][0].shape[2]
                     # `new_token_len` can be 1 or 2 (next token in assistant + last token picked by the larger model)
                     new_token_len = candidate_input_ids.shape[1] - prev_seq_len
-                    tmp_inputs = candidate_input_ids[:, -new_token_len:]
-                    tmp_attn = torch.ones_like(candidate_input_ids)
+                    assist_inputs = candidate_input_ids[:, -new_token_len:]
+                    assist_attn = torch.ones_like(candidate_input_ids)
                     # TODO (joao): make it compatible with models that use unconventional fwd pass logic, like blip2
                     if assistant_model.config.is_encoder_decoder:
                         assistant_model_outputs = assistant_model(
-                            decoder_input_ids=tmp_inputs,
-                            decoder_attention_mask=tmp_attn,
+                            decoder_input_ids=assist_inputs,
+                            decoder_attention_mask=assist_attn,
                             past_key_values=model_kwargs["assistant_past_key_values"],
                             encoder_outputs=model_kwargs["assistant_encoder_outputs"],
                         )
                     else:
                         assistant_model_outputs = assistant_model(
-                            tmp_inputs,
-                            attention_mask=tmp_attn,
+                            assist_inputs,
+                            attention_mask=assist_attn,
                             past_key_values=model_kwargs["assistant_past_key_values"],
                         )
                 else:
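In the branch above, the assistant is only fed the tokens its cache has not seen yet: `prev_seq_len` is read off the cached key tensor and `new_token_len` is the difference, which is 1 or 2 (the assistant's own last draft, plus possibly the token the main model picked). A shape-only sketch of that bookkeeping, with made-up tensors standing in for a real cache:

```python
# Toy illustration of the `new_token_len` bookkeeping above (shapes only, no real model).
import torch

candidate_input_ids = torch.arange(12).reshape(1, 12)    # sequence drafted so far: 12 tokens
# past_key_values entries have shape (batch, heads, seq_len, head_dim); seq_len is what matters here
assistant_cache_key = torch.zeros(1, 4, 10, 8)           # assistant cache already covers 10 tokens

prev_seq_len = assistant_cache_key.shape[2]                   # 10
new_token_len = candidate_input_ids.shape[1] - prev_seq_len   # 2: assistant's last draft + main model's pick
assist_inputs = candidate_input_ids[:, -new_token_len:]       # only the uncached tokens are fed to the assistant
assist_attn = torch.ones_like(candidate_input_ids)            # the attention mask still covers the full sequence
print(assist_inputs.shape, assist_attn.shape)                 # torch.Size([1, 2]) torch.Size([1, 12])
```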
@@ -4296,16 +4295,17 @@ class GenerationMixin:
             candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
             # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
-            # `candidate_length + 1` relevant logits from this process (see step 7 on why the +1)
+            # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
+            # we use this forward pass to also pick the subsequent logits in the original model.
             # 2.1. Run a forward pass on the candidate sequence
             if "past_key_values" in model_kwargs:
-                og_model_attn = torch.ones_like(candidate_input_ids)
-                og_model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
+                model_attn = torch.ones_like(candidate_input_ids)
+                model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
                 if self.config.is_encoder_decoder:
                     outputs = self(
-                        decoder_input_ids=og_model_input_ids,
-                        decoder_attention_mask=og_model_attn,
+                        decoder_input_ids=model_input_ids,
+                        decoder_attention_mask=model_attn,
                         past_key_values=model_kwargs["past_key_values"],
                         encoder_outputs=model_kwargs["encoder_outputs"],
                         output_attentions=output_attentions,
@@ -4313,8 +4313,8 @@ class GenerationMixin:
                     )
                 else:
                     outputs = self(
-                        og_model_input_ids,
-                        attention_mask=og_model_attn,
+                        model_input_ids,
+                        attention_mask=model_attn,
                         past_key_values=model_kwargs["past_key_values"],
                         output_attentions=output_attentions,
                         output_hidden_states=output_hidden_states,
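Step 2 runs the main model once over the candidate sequence: with a cache covering the earlier positions, only the last `candidate_length + 1` ids need to be fed, and the output contains `candidate_length + 1` useful logit rows, one to check each drafted token plus one extra row that already supplies the next token when every candidate is accepted. A shape-level sketch with random logits standing in for a real forward pass:

```python
# Shape-level sketch of step 2; variable names mirror the diff, the logits are random stand-ins.
import torch

cur_len, candidate_length, vocab_size = 8, 4, 32
candidate_input_ids = torch.randint(vocab_size, (1, cur_len + candidate_length))

# Only the last `candidate_length + 1` ids are fed; earlier positions are covered by the cache.
model_input_ids = candidate_input_ids[:, -candidate_length - 1 :]
# The model returns one logit row per input position.
logits = torch.randn(1, model_input_ids.shape[1], vocab_size)

# Rows checking each candidate, plus one row past the last candidate.
new_logits = logits[:, -candidate_length - 1 :, :]
print(model_input_ids.shape, new_logits.shape)   # torch.Size([1, 5]) torch.Size([1, 5, 32])
```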
@@ -4343,59 +4343,51 @@ class GenerationMixin:
             for i in range(candidate_length):
                 new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
-            # 3. Obtain the next tokens from the original model logits. If `do_sample` is True, use multinomial
-            # sampling, otherwise use argmax.
+            # 3. Obtain the next tokens from the original model logits.
             if do_sample:
                 probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
-                sampled_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
-                next_tokens = sampled_tokens[:, :-1]
+                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
             else:
-                next_tokens = new_logits[:, -candidate_length - 1 : -1, :].argmax(dim=-1)
+                selected_tokens = new_logits[:, -candidate_length - 1 :, :].argmax(dim=-1)
             # 4. Compare the argmax from the original model logits with the assistant forecasted tokens. We can keep
             # the assistant forecasted tokens until the first mismatch, or until the max length is reached.
             candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
-            n_matches = ((~(candidate_new_tokens == next_tokens)).cumsum(dim=-1) < 1).sum()
+            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()
-            # 5. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
-            # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
-            # cost of forecasting incorrect assistant tokens.
-            if n_matches == int(assistant_model.max_assistant_tokens):
-                assistant_model.max_assistant_tokens += 2.0
-            else:
-                assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
-            # 6. Update variables according to the number of matching assistant tokens.
-            # 6.1. Ensure we don't generate beyond max_len or an EOS token (remember: one token will be added below)
-            n_matches = min(n_matches, max_len - cur_len - 1)
+            # 5. Update variables according to the number of matching assistant tokens. Remember: the token generated
+            # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
+            # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
+            # is no match.
+            # 5.1. Ensure we don't generate beyond max_len or an EOS token
             if last_assistant_token_is_eos and n_matches == candidate_length:
                 n_matches -= 1
-            input_ids = candidate_input_ids[:, 0 : cur_len + n_matches]
-            new_cur_len = input_ids.shape[-1]
+            n_matches = min(n_matches, max_len - cur_len - 1)
+            # 5.2. Get the valid continuation, after the matching tokens
+            valid_tokens = selected_tokens[:, : n_matches + 1]
+            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
             if streamer is not None:
-                streamer.put(candidate_input_ids[:, cur_len : cur_len + n_matches])
-            # 6.2. Discard past key values relative to unused assistant tokens
-            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cur_len)
+                streamer.put(valid_tokens.cpu())
+            new_cur_len = input_ids.shape[-1]
+            # 5.3. Discard past key values relative to unused assistant tokens
+            new_cache_size = new_cur_len - 1
+            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
             model_kwargs["assistant_past_key_values"] = _crop_past_key_values(
-                assistant_model, model_kwargs["assistant_past_key_values"], new_cur_len
-            )
+                assistant_model, model_kwargs["assistant_past_key_values"], new_cache_size - 1
+            )  # the assistant does not have the token after the last match, hence the -1
-            # 6.3. Extract the logits for the next token
-            next_token_scores = new_logits[:, n_matches, :]
-            # 7. Use the set of logits after the last matching assistant token to obtain the next token. Note that,
-            # because of this step, assisted generation search reduces to a normal greedy search/sample if there is no
-            # match.
-            if do_sample:
-                probs = probs[:, n_matches, :]
-                next_tokens = sampled_tokens[:, n_matches]
-            else:
-                next_tokens = torch.argmax(next_token_scores, dim=-1)
-            # Assistant: main logic end; Compared to greedy search/sample, the following (redundant) blocks were
-            # removed below: (1) model input preparation; (2) model forward pass; (3) score preparation; (4) model
-            # cache update.
+            # 6. Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
+            # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
+            # cost of forecasting incorrect assistant tokens.
+            if n_matches == int(assistant_model.max_assistant_tokens):
+                assistant_model.max_assistant_tokens += 2.0
+            else:
+                assistant_model.max_assistant_tokens = max(1.0, assistant_model.max_assistant_tokens - 1.0)
+            # Assistant: main logic end
             if synced_gpus and this_peer_finished:
                 continue  # don't waste resources running the code we don't need
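The core of the verification step is the `n_matches` expression in step 4: it compares the drafted tokens against the tokens the main model itself would have selected, and counts agreements up to the first mismatch via a cumulative sum. A self-contained check with toy token ids:

```python
# Self-contained check of the `n_matches` trick from step 4 (toy ids, not real model output).
import torch

candidate_new_tokens = torch.tensor([[11, 22, 33, 44, 55]])   # what the assistant drafted
selected_tokens = torch.tensor([[11, 22, 99, 44, 55, 66]])    # main model's pick per position (+1 bonus token)

mismatch = ~(candidate_new_tokens == selected_tokens[:, :-1])  # True where the draft disagrees with the model
n_matches = (mismatch.cumsum(dim=-1) < 1).sum()                # positions before the first disagreement
print(int(n_matches))  # 2: tokens 11 and 22 are kept, and valid_tokens = selected_tokens[:, :3] = [11, 22, 99]
```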
@@ -4407,20 +4399,20 @@ class GenerationMixin:
                 scores += tuple(new_logits[:, i, :] for i in range(n_matches + 1))
                 if "past_key_values" not in model_kwargs:
-                    last_matching_idx = new_cur_len - 1
+                    added_len = new_cur_len
                 else:
-                    last_matching_idx = n_matches
+                    added_len = n_matches + 1
                 if output_attentions:
                     if self.config.is_encoder_decoder:
                         cross_attentions = _split_model_outputs(
-                            cross_attentions, outputs.cross_attentions, cur_len, last_matching_idx
+                            cross_attentions, outputs.cross_attentions, cur_len, added_len
                         )
                         decoder_attentions = _split_model_outputs(
                             decoder_attentions,
                             outputs.decoder_attentions,
                             cur_len,
-                            last_matching_idx,
+                            added_len,
                             is_decoder_attention=True,
                         )
                     else:
@@ -4428,28 +4420,19 @@ class GenerationMixin:
                             decoder_attentions,
                             outputs.attentions,
                             cur_len,
-                            last_matching_idx,
+                            added_len,
                             is_decoder_attention=True,
                         )
                 if output_hidden_states:
                     if self.config.is_encoder_decoder:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, last_matching_idx
+                            decoder_hidden_states, outputs.decoder_hidden_states, cur_len, added_len
                         )
                     else:
                         decoder_hidden_states = _split_model_outputs(
-                            decoder_hidden_states, outputs.hidden_states, cur_len, last_matching_idx
+                            decoder_hidden_states, outputs.hidden_states, cur_len, added_len
                         )
-            # finished sentences should have their next token be a padding token
-            if eos_token_id is not None:
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
-            # update generated ids, model inputs, and length for next step
-            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-            if streamer is not None:
-                streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
@@ -4457,7 +4440,10 @@ class GenerationMixin:
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                    input_ids[:, -1]
+                    .tile(eos_token_id_tensor.shape[0], 1)
+                    .ne(eos_token_id_tensor.unsqueeze(1))
+                    .prod(dim=0)
                 )

             # stop when each sentence is finished
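Because a whole block of accepted tokens is now appended at once, the end-of-sequence check looks at `input_ids[:, -1]` rather than a separately sampled `next_tokens`. A toy run of the `tile`/`ne`/`prod` pattern used to zero out finished sequences (token values are made up):

```python
# Toy check of the `unfinished_sequences` update: a sequence is marked finished (0) once its
# last token equals any of the EOS ids.
import torch

input_ids = torch.tensor([[5, 9, 2], [5, 9, 7]])     # batch of 2; eos id 2 ends the first sequence
eos_token_id_tensor = torch.tensor([2, 3])           # more than one EOS id is allowed
unfinished_sequences = torch.ones(2, dtype=torch.long)

unfinished_sequences = unfinished_sequences.mul(
    input_ids[:, -1]
    .tile(eos_token_id_tensor.shape[0], 1)           # shape (num_eos, batch)
    .ne(eos_token_id_tensor.unsqueeze(1))            # True where the last token is NOT that eos id
    .prod(dim=0)                                     # 0 if any eos id matched
)
print(unfinished_sequences)  # tensor([0, 1])
```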
@@ -4531,7 +4517,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
     return past_key_values


-def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
+def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
     """
     Given the (decoder/cross attentions)/(decoder hidden states) for multiple generated tokens, splits it into a tuple
     where each member corresponds to a single generated token.
@@ -4541,16 +4527,17 @@ def _split_model_outputs(outputs, new_outputs, previous_cur_len, last_matching_idx, is_decoder_attention=False):
     if len(outputs) == 0:
         new_tuple = ()
         for layer in new_outputs:
-            last_dim_size = previous_cur_len if is_decoder_attention else layer.shape[-1]
-            new_tuple += (layer[..., :previous_cur_len, :last_dim_size],)
+            last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
+            new_tuple += (layer[..., :cur_len, :last_dim_size],)
         outputs += (new_tuple,)
-        last_matching_idx -= previous_cur_len
-        previous_cur_len += 1
+        # The first iteration contains the prompt + 1 generated token, let's update the length variables accordingly
+        cur_len += 1
+        added_len -= cur_len

-    for i in range(last_matching_idx + 1):
+    for i in range(added_len):
         new_tuple = ()
         for layer in new_outputs:
-            last_dim_size = previous_cur_len + i if is_decoder_attention else layer.shape[-1]
+            last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
             new_tuple += (layer[..., i : i + 1, :last_dim_size],)
         outputs += (new_tuple,)
     return outputs
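`_split_model_outputs` now receives `cur_len` and `added_len` and turns the stacked per-layer outputs of one multi-token forward pass into one tuple entry per token, with the first entry (first call only) covering the prompt. A dummy-tensor sketch that re-implements the new helper verbatim, just to show the resulting shapes; `added_len` is set to the full sequence length, as the caller does when there is no cache:

```python
# Dummy-tensor sketch of the splitting behaviour of the new `_split_model_outputs`.
import torch

def split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
    # Re-implementation of the helper above, copied for illustration.
    if len(outputs) == 0:
        new_tuple = ()
        for layer in new_outputs:
            last_dim_size = cur_len if is_decoder_attention else layer.shape[-1]
            new_tuple += (layer[..., :cur_len, :last_dim_size],)
        outputs += (new_tuple,)
        cur_len += 1
        added_len -= cur_len
    for i in range(added_len):
        new_tuple = ()
        for layer in new_outputs:
            last_dim_size = cur_len + i if is_decoder_attention else layer.shape[-1]
            new_tuple += (layer[..., i : i + 1, :last_dim_size],)
        outputs += (new_tuple,)
    return outputs

# One layer of decoder attention with shape (batch, heads, query_len, key_len) for a 7-token sequence.
layer_attn = (torch.rand(1, 2, 7, 7),)
split = split_model_outputs((), layer_attn, cur_len=4, added_len=7, is_decoder_attention=True)
print(len(split), [t[0].shape for t in split])
# 3 [torch.Size([1, 2, 4, 4]), torch.Size([1, 2, 1, 5]), torch.Size([1, 2, 1, 6])]
```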
tests/generation/test_utils.py

@@ -1518,8 +1518,8 @@ class GenerationTesterMixin:
                 self._check_outputs(output, input_ids, model.config, use_cache=True)

     def test_assisted_decoding_sample(self):
-        # Seeded assisted decoding will not match sample for the same seed, as there are >1 sampling steps per output
-        # token. As such, this test only checks that the output format is correct.
+        # Seeded assisted decoding will not match sample for the same seed, as the forward pass does not return the
+        # exact same logits (the forward pass of the main model, now with several tokens at once, has causal masking).
         for model_class in self.all_generative_model_classes:
             # won't fix: FSMT and Reformer have a different cache variable type (and format).