chenpangpang / transformers

Unverified commit 45b70384, authored Dec 20, 2023 by Joao Gante, committed by GitHub on Dec 20, 2023
Generate: fix speculative decoding (#28166)
Co-authored-by: Merve Noyan <merveenoyan@gmail.com>
Parent: 01c081d1
Showing 5 changed files with 90 additions and 72 deletions (+90 −72)

docs/source/en/generation_strategies.md               +12 −9
src/transformers/generation/candidate_generator.py    +18 −23
src/transformers/generation/utils.py                  +31 −35
src/transformers/models/whisper/modeling_whisper.py    +4 −4
tests/models/mistral/test_modeling_mistral.py          +25 −1
docs/source/en/generation_strategies.md (view file @ 45b70384)

@@ -82,7 +82,7 @@ Even if the default decoding strategy mostly works for your task, you can still
 commonly adjusted parameters include:

 - `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not
-including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose
+including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose
 to stop generation whenever the full generation exceeds some amount of time. To learn more, check [`StoppingCriteria`].
 - `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to
 beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that
@@ -339,13 +339,16 @@ This guide illustrates the main parameters that enable various decoding strategies
 [`generate`] method, which gives you even further control over the [`generate`] method's behavior.
 For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md).

-### Assisted Decoding
+### Speculative Decoding

-Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
-the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
-and sampling are supported with assisted decoding, and doesn't support batched inputs. To learn more about assisted
-decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
+Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
+assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
+model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
+`do_sample=True`, then the token validation with resampling introduced in the
+[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
+
+Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
+To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).

 To enable assisted decoding, set the `assistant_model` argument with a model.
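For reference, a minimal runnable sketch of what the reworded docs describe, i.e. enabling assisted decoding by passing an `assistant_model` to `generate` (the checkpoint names are illustrative; the assistant must use the same tokenizer as the main model):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoints: a larger main model and a much smaller assistant sharing its tokenizer.
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Alice and Bob", return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

# Passing `assistant_model` switches `generate` to assisted (speculative) decoding.
outputs = model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```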
@@ -366,8 +369,8 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```

-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
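And a sketch of the sampling path that paragraph refers to: with `do_sample=True`, candidate validation uses the resampling scheme from the paper, and `temperature` behaves as in multinomial sampling (checkpoints, seed, and temperature value are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

set_seed(42)  # sampling is stochastic; fix the seed to get a reproducible continuation

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Alice and Bob", return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

# `do_sample=True` enables the speculative-sampling validation; a lower temperature can reduce latency.
outputs = model.generate(
    **inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```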
src/transformers/generation/candidate_generator.py (view file @ 45b70384)

@@ -14,14 +14,14 @@
 # limitations under the License.

 import copy
-import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple

 import torch

 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
     from .configuration_utils import GenerationConfig
     from .logits_process import LogitsProcessorList
@@ -66,14 +66,17 @@ class CandidateGenerator:

 class AssistedCandidateGenerator(CandidateGenerator):
     """
-    `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of
-    a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation
+    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+    candidates through the use of a smaller model. Read the following blog post for more information:
+    https://huggingface.co/blog/assisted-generation

     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
         assistant_model (`PreTrainedModel`):
             The model to be used for generating candidates. This model should be smaller than the main model.
+        generation_config (`~generation.GenerationConfig`, *optional*):
+            The generation configuration to be used as base parametrization for the generation call.
         logits_processor (`LogitsProcessorList`):
             An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
             used to modify the prediction scores of the language modeling head applied at each generation step.
@@ -82,31 +85,20 @@ class AssistedCandidateGenerator(CandidateGenerator):
             model as well.
         inputs_tensor (`torch.Tensor`, *optional*):
             The model input tensor. In encoder-decoder models, this is the encoder input.
-        eos_token_id (`Union[int, List[int]]`, *optional*):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """

     def __init__(
         self,
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
+        generation_config: "GenerationConfig",
         logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
     ):
         # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
-        # Prepare the number of candidate tokens
-        if hasattr(assistant_model, "num_assistant_tokens"):
-            warnings.warn(
-                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be "
-                "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
-                FutureWarning,
-            )
-            self.num_assistant_tokens = assistant_model.num_assistant_tokens
-        else:
-            self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}
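The removed branch enforced what its own deprecation message asked for: the candidate budget is now always read from the assistant's generation config. As a hedged illustration (the checkpoint name is assumed), that option is set like this:

```python
from transformers import AutoModelForCausalLM

# Illustrative assistant checkpoint; `num_assistant_tokens` lives on its generation config.
assistant_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
assistant_model.generation_config.num_assistant_tokens = 8  # candidate tokens generated per round
```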
@@ -145,13 +137,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
             self.input_ids_key = "input_ids"
             self.attention_key = "attention_mask"

-        # Prepare other attributes
+        # Prepare generation-related options.
+        eos_token_id = generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         self.eos_token_id_tensor = (
             torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         )
         self.logits_processor = logits_processor
+        self.generation_config = copy.deepcopy(generation_config)
+        self.generation_config.return_dict_in_generate = True
+        self.generation_config.output_scores = True

     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
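For illustration, a small self-contained sketch of the setup performed above: normalize `eos_token_id` (which may be an int, a list, or `None`) into an optional tensor, and keep a deep copy of the generation config with `return_dict_in_generate` and `output_scores` forced on. The config values below are made up:

```python
import copy

import torch
from transformers import GenerationConfig

# Hypothetical base config, standing in for the one the user passed to `generate`.
generation_config = GenerationConfig(eos_token_id=2, num_assistant_tokens=5)

# Normalize int / list / None into an optional tensor, as in the constructor above.
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id) if eos_token_id is not None else None

# The assistant keeps its own copy of the config and forces the outputs it needs for validation.
assistant_generation_config = copy.deepcopy(generation_config)
assistant_generation_config.return_dict_in_generate = True
assistant_generation_config.output_scores = True
print(eos_token_id_tensor, assistant_generation_config.output_scores)
```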
@@ -185,12 +181,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "do_sample": False,
-            "num_beams": 1,
             "max_new_tokens": int(self.num_assistant_tokens),
-            "return_dict_in_generate": True,
-            "output_scores": True,
+            "generation_config": self.generation_config,
+            "logits_processor": self.logits_processor,
         }

         assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

         # 3. Update variables for the next round of candidate generation
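A standalone sketch of the forwarding pattern used in `get_candidates`: the assistant's `generate` call receives the shared `GenerationConfig`, while ad-hoc kwargs such as `max_new_tokens` override individual fields, and the scores come back because the config forces `return_dict_in_generate` and `output_scores`. The tiny checkpoint below is only a stand-in for a real assistant model:

```python
import copy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")  # stand-in assistant model
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

assistant_config = copy.deepcopy(assistant.generation_config)
assistant_config.return_dict_in_generate = True
assistant_config.output_scores = True

inputs = tokenizer("The quick brown fox", return_tensors="pt")
assistant_output = assistant.generate(
    **inputs, max_new_tokens=5, generation_config=assistant_config
)

candidate_ids = assistant_output.sequences                      # prompt + candidate tokens
candidate_logits = torch.stack(assistant_output.scores, dim=1)  # (batch, new_tokens, vocab)
print(candidate_ids.shape, candidate_logits.shape)
```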
src/transformers/generation/utils.py (view file @ 45b70384)

@@ -911,10 +911,10 @@ class GenerationMixin:
         candidate_generator = AssistedCandidateGenerator(
             input_ids=input_ids,
             assistant_model=assistant_model,
+            generation_config=generation_config,
             logits_processor=logits_processor,
             model_kwargs=model_kwargs,
             inputs_tensor=inputs_tensor,
-            eos_token_id=generation_config.eos_token_id,
         )
         return candidate_generator
@@ -1673,7 +1673,7 @@ class GenerationMixin:
         )
         # 8. prepare distribution pre_processing samplers
-        logits_processor = self._get_logits_processor(
+        prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,

@@ -1685,7 +1685,7 @@ class GenerationMixin:
         )
         # 9. prepare stopping criteria
-        stopping_criteria = self._get_stopping_criteria(
+        prepared_stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
         # 10. go into different generation modes
@@ -1715,9 +1715,9 @@ class GenerationMixin:
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1730,8 +1730,8 @@ class GenerationMixin:
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1749,8 +1749,8 @@ class GenerationMixin:
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1776,9 +1776,9 @@ class GenerationMixin:
             # 13. run sample
             return self.sample(
                 input_ids,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1810,8 +1810,8 @@ class GenerationMixin:
             return self.beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1847,9 +1847,9 @@ class GenerationMixin:
             return self.beam_sample(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -1881,8 +1881,8 @@ class GenerationMixin:
             return self.group_beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1954,8 +1954,8 @@ class GenerationMixin:
             return self.constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,
@@ -4629,7 +4629,7 @@ class GenerationMixin:
        # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
        max_matches = max_len - cur_len - 1
        if do_sample and candidate_logits is not None:
-            next_sampled_tokens, n_matches = _speculative_sampling(
+            valid_tokens, n_matches = _speculative_sampling(
                candidate_input_ids,
                candidate_logits,
                candidate_length,

@@ -4637,8 +4637,6 @@ class GenerationMixin:
                last_assistant_token_is_eos,
                max_matches,
            )
-            # The selected tokens include the matches plus the next sampled tokens
-            selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)

        # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
        # original model logits with the candidate tokens. We can keep the candidate tokens until the first

@@ -4657,6 +4655,7 @@ class GenerationMixin:
            if last_assistant_token_is_eos and n_matches == candidate_length:
                n_matches -= 1
            n_matches = min(n_matches, max_matches)
+            valid_tokens = selected_tokens[:, : n_matches + 1]

        # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
        # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
@@ -4664,7 +4663,6 @@ class GenerationMixin:
        # is no match.

        # 4.1. Get the valid continuation, after the matching tokens
-        valid_tokens = selected_tokens[:, : n_matches + 1]
        input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
        if streamer is not None:
            streamer.put(valid_tokens.cpu())
@@ -4782,24 +4780,16 @@ def _speculative_sampling(
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
-    the next selected token, as well as the number of candidate matches.
+    the selected tokens, as well as the number of candidate matches.

     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
     # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[
-        :,
-        torch.range(0, candidate_length - 1, dtype=torch.int),
-        candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[
-        :,
-        torch.range(0, candidate_length - 1, dtype=torch.int),
-        candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     probability_ratio = p_i / q_i

     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
@@ -4824,7 +4814,13 @@ def _speculative_sampling(
         p_prime = p_n_plus_1
     t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

-    return t, n_matches
+    # The selected tokens include the matches (if any) plus the next sampled tokens
+    if n_matches > 0:
+        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
+    else:
+        valid_tokens = t
+
+    return valid_tokens, n_matches


 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
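To make the acceptance rule of algorithm 1 concrete, here is a toy, self-contained sketch with made-up probabilities: candidate token i is accepted with probability min(1, p_i / q_i), the first rejection ends the accepted prefix, and the prefix is then extended with one freshly sampled token, mirroring the `valid_tokens` construction above. All tensors below are illustrative:

```python
import torch

torch.manual_seed(0)

# Made-up per-candidate probabilities: q_i from the assistant, p_i from the main model.
q_i = torch.tensor([[0.70, 0.60, 0.50, 0.40]])
p_i = torch.tensor([[0.65, 0.55, 0.10, 0.30]])
probability_ratio = p_i / q_i

# Accept candidate i while p_i / q_i survives a uniform draw; the first rejection ends the run.
r_i = torch.rand_like(probability_ratio)
is_accepted = r_i <= probability_ratio
n_matches = int(((~is_accepted).cumsum(dim=-1) < 1).sum())

# Pretend candidate token ids and a freshly sampled token appended after the accepted prefix.
candidate_tokens = torch.tensor([[11, 22, 33, 44]])
t = torch.tensor([[99]])
valid_tokens = torch.cat((candidate_tokens[:, :n_matches], t), dim=-1) if n_matches > 0 else t
print(n_matches, valid_tokens)
```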
src/transformers/models/whisper/modeling_whisper.py (view file @ 45b70384)

@@ -2037,15 +2037,15 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                 FutureWarning,
             )

+        if generation_config is None:
+            generation_config = copy.deepcopy(self.generation_config)
+
         return_dict_in_generate = (
             return_dict_in_generate
             if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
+            else generation_config.return_dict_in_generate
         )

-        if generation_config is None:
-            generation_config = copy.deepcopy(self.generation_config)
-
         input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         if num_segment_frames is None:
             num_segment_frames = input_stride * self.config.max_source_positions
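The Whisper hunk reorders the defaulting so that per-call flags are resolved against the possibly user-supplied `generation_config` rather than `self.generation_config`. A hypothetical, minimal sketch of that resolution order (the helper name is made up):

```python
import copy

from transformers import GenerationConfig


def resolve_return_dict_in_generate(return_dict_in_generate, generation_config, model_generation_config):
    # Default the config first, then resolve the per-call flag against *that* config.
    if generation_config is None:
        generation_config = copy.deepcopy(model_generation_config)
    return (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else generation_config.return_dict_in_generate
    )


model_default = GenerationConfig()                            # stands in for self.generation_config
user_config = GenerationConfig(return_dict_in_generate=True)  # passed by the caller
print(resolve_return_dict_in_generate(None, user_config, model_default))  # True
print(resolve_return_dict_in_generate(None, None, model_default))         # False
```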
tests/models/mistral/test_modeling_mistral.py (view file @ 45b70384)

@@ -21,7 +21,7 @@ import unittest

 import pytest

-from transformers import AutoTokenizer, MistralConfig, is_torch_available
+from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,

@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+    @slow
+    def test_speculative_generation(self):
+        EXPECTED_TEXT_COMPLETION = (
+            "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
+        )
+        prompt = "My favourite condiment is "
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
+        )
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        set_seed(0)
+        generated_ids = model.generate(
+            input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
+        )
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()