chenpangpang / transformers · Commits · 45b70384

Commit 45b70384 (unverified), authored Dec 20, 2023 by Joao Gante, committed by GitHub on Dec 20, 2023

Generate: fix speculative decoding (#28166)

Co-authored-by: Merve Noyan <merveenoyan@gmail.com>

Parent: 01c081d1

Showing 5 changed files with 90 additions and 72 deletions (+90 / -72)
Changed files:
- docs/source/en/generation_strategies.md (+12, -9)
- src/transformers/generation/candidate_generator.py (+18, -23)
- src/transformers/generation/utils.py (+31, -35)
- src/transformers/models/whisper/modeling_whisper.py (+4, -4)
- tests/models/mistral/test_modeling_mistral.py (+25, -1)
docs/source/en/generation_strategies.md  (view file @ 45b70384)

@@ -82,7 +82,7 @@ Even if the default decoding strategy mostly works for your task, you can still
 commonly adjusted parameters include:
 - `max_new_tokens`: the maximum number of tokens to generate. In other words, the size of the output sequence, not
 including the tokens in the prompt. As an alternative to using the output's length as a stopping criteria, you can choose
 to stop generation whenever the full generation exceeds some amount of time. To learn more, check [`StoppingCriteria`].
 - `num_beams`: by specifying a number of beams higher than 1, you are effectively switching from greedy search to
 beam search. This strategy evaluates several hypotheses at each time step and eventually chooses the hypothesis that

@@ -339,13 +339,16 @@ This guide illustrates the main parameters that enable various decoding strategi
 [`generate`] method, which gives you even further control over the [`generate`] method's behavior.
 For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.md).
 
-### Assisted Decoding
+### Speculative Decoding
 
-Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
-the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
-and sampling are supported with assisted decoding, and doesn't support batched inputs. To learn more about assisted
-decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
+Speculative decoding (also known as assisted decoding) is a modification of the decoding strategies above, that uses an
+assistant model (ideally a much smaller one) with the same tokenizer, to generate a few candidate tokens. The main
+model then validates the candidate tokens in a single forward pass, which speeds up the decoding process. If
+`do_sample=True`, then the token validation with resampling introduced in the
+[speculative decoding paper](https://arxiv.org/pdf/2211.17192.pdf) is used.
+
+Currently, only greedy search and sampling are supported with assisted decoding, and assisted decoding doesn't support batched inputs.
+To learn more about assisted decoding, check [this blog post](https://huggingface.co/blog/assisted-generation).
 
 To enable assisted decoding, set the `assistant_model` argument with a model.

@@ -366,8 +369,8 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
-When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness
-just like in multinomial sampling. However, in assisted decoding, reducing the temperature will help improving latency.
+When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
+just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.
 
 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
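For reference, below is a minimal sketch of the call pattern documented above. The checkpoint names are illustrative (any main/assistant pair that shares a tokenizer works), and the printed continuation will vary.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

# Illustrative checkpoints: the assistant should be a much smaller model with the same tokenizer.
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)

inputs = tokenizer("Alice and Bob", return_tensors="pt")

# With do_sample=True, token validation uses the resampling scheme from the speculative
# decoding paper; lowering temperature may reduce latency at the cost of less random output.
set_seed(42)
outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, temperature=0.5)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```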
src/transformers/generation/candidate_generator.py  (view file @ 45b70384)

@@ -14,14 +14,14 @@
 # limitations under the License.
 
 import copy
-import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
 
 import torch
 
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
+    from .configuration_utils import GenerationConfig
     from .logits_process import LogitsProcessorList

@@ -66,14 +66,17 @@ class CandidateGenerator:
 class AssistedCandidateGenerator(CandidateGenerator):
     """
-    `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of
-    a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation
+    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
+    candidates through the use of a smaller model. Read the following blog post for more information:
+    https://huggingface.co/blog/assisted-generation
 
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
         assistant_model (`PreTrainedModel`):
             The model to be used for generating candidates. This model should be smaller than the main model.
+        generation_config (`~generation.GenerationConfig`, *optional*):
+            The generation configuration to be used as base parametrization for the generation call.
         logits_processor (`LogitsProcessorList`):
             An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
             used to modify the prediction scores of the language modeling head applied at each generation step.

@@ -82,31 +85,20 @@ class AssistedCandidateGenerator(CandidateGenerator):
             model as well.
         inputs_tensor (`torch.Tensor`, *optional*):
             The model input tensor. In encoder-decoder models, this is the encoder input.
-        eos_token_id (`Union[int, List[int]]`, *optional*):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """
 
     def __init__(
         self,
         input_ids: torch.LongTensor,
         assistant_model: "PreTrainedModel",
+        generation_config: "GenerationConfig",
         logits_processor: "LogitsProcessorList",
         model_kwargs: Dict,
         inputs_tensor: Optional[torch.Tensor] = None,
-        eos_token_id: Optional[Union[int, List[int]]] = None,
     ):
+        # Prepare the assistant and the starting number of candidate tokens
         self.assistant_model = assistant_model
-
-        # Prepare the number of candidate tokens
-        if hasattr(assistant_model, "num_assistant_tokens"):
-            warnings.warn(
-                "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be "
-                "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.",
-                FutureWarning,
-            )
-            self.num_assistant_tokens = assistant_model.num_assistant_tokens
-        else:
-            self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
+        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
 
         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}

@@ -145,13 +137,17 @@ class AssistedCandidateGenerator(CandidateGenerator):
         self.input_ids_key = "input_ids"
         self.attention_key = "attention_mask"
 
-        # Prepare other attributes
+        # Prepare generation-related options.
+        eos_token_id = generation_config.eos_token_id
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
         self.eos_token_id_tensor = (
             torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
         )
         self.logits_processor = logits_processor
+        self.generation_config = copy.deepcopy(generation_config)
+        self.generation_config.return_dict_in_generate = True
+        self.generation_config.output_scores = True
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """

@@ -185,12 +181,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # 2. Forecast next N tokens using the assistant model.
         assistant_generation_kwargs = {
             self.input_ids_key: input_ids,
-            "do_sample": False,
-            "num_beams": 1,
             "max_new_tokens": int(self.num_assistant_tokens),
-            "return_dict_in_generate": True,
-            "output_scores": True,
+            "generation_config": self.generation_config,
+            "logits_processor": self.logits_processor,
         }
 
         assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
 
         # 3. Update variables for the next round of candidate generation
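With the deprecated `assistant_model.num_assistant_tokens` attribute path removed above, the number of drafted candidate tokens is read from the assistant's generation config. A hedged usage sketch (checkpoint names and the value 8 are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoints; any main/assistant pair with a shared tokenizer applies.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b-deduped")
assistant_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")

# Number of candidate tokens the assistant drafts per step, set on its generation config
# (setting assistant_model.num_assistant_tokens directly is the removed, deprecated path).
assistant_model.generation_config.num_assistant_tokens = 8

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```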
src/transformers/generation/utils.py  (view file @ 45b70384)

@@ -911,10 +911,10 @@ class GenerationMixin:
             candidate_generator = AssistedCandidateGenerator(
                 input_ids=input_ids,
                 assistant_model=assistant_model,
+                generation_config=generation_config,
                 logits_processor=logits_processor,
                 model_kwargs=model_kwargs,
                 inputs_tensor=inputs_tensor,
-                eos_token_id=generation_config.eos_token_id,
             )
         return candidate_generator

@@ -1673,7 +1673,7 @@ class GenerationMixin:
         )
 
         # 8. prepare distribution pre_processing samplers
-        logits_processor = self._get_logits_processor(
+        prepared_logits_processor = self._get_logits_processor(
             generation_config=generation_config,
             input_ids_seq_length=input_ids_length,
             encoder_input_ids=inputs_tensor,

@@ -1685,7 +1685,7 @@ class GenerationMixin:
         )
 
         # 9. prepare stopping criteria
-        stopping_criteria = self._get_stopping_criteria(
+        prepared_stopping_criteria = self._get_stopping_criteria(
             generation_config=generation_config, stopping_criteria=stopping_criteria
         )
         # 10. go into different generation modes

@@ -1715,9 +1715,9 @@ class GenerationMixin:
                 input_ids,
                 candidate_generator=candidate_generator,
                 do_sample=generation_config.do_sample,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=self._get_logits_warper(generation_config) if generation_config.do_sample else None,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1730,8 +1730,8 @@ class GenerationMixin:
             # 11. run greedy search
             return self.greedy_search(
                 input_ids,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1749,8 +1749,8 @@ class GenerationMixin:
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1776,9 +1776,9 @@ class GenerationMixin:
             # 13. run sample
             return self.sample(
                 input_ids,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1810,8 +1810,8 @@ class GenerationMixin:
             return self.beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1847,9 +1847,9 @@ class GenerationMixin:
             return self.beam_sample(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
+                logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
-                stopping_criteria=stopping_criteria,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1881,8 +1881,8 @@ class GenerationMixin:
             return self.group_beam_search(
                 input_ids,
                 beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -1954,8 +1954,8 @@ class GenerationMixin:
             return self.constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
-                logits_processor=logits_processor,
-                stopping_criteria=stopping_criteria,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
                 output_scores=generation_config.output_scores,

@@ -4629,7 +4629,7 @@ class GenerationMixin:
         # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
         max_matches = max_len - cur_len - 1
         if do_sample and candidate_logits is not None:
-            next_sampled_tokens, n_matches = _speculative_sampling(
+            valid_tokens, n_matches = _speculative_sampling(
                 candidate_input_ids,
                 candidate_logits,
                 candidate_length,

@@ -4637,8 +4637,6 @@ class GenerationMixin:
                 last_assistant_token_is_eos,
                 max_matches,
             )
-            # The selected tokens include the matches plus the next sampled tokens
-            selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], next_sampled_tokens), dim=-1)
 
         # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
         # original model logits with the candidate tokens. We can keep the candidate tokens until the first

@@ -4657,6 +4655,7 @@ class GenerationMixin:
             if last_assistant_token_is_eos and n_matches == candidate_length:
                 n_matches -= 1
             n_matches = min(n_matches, max_matches)
+            valid_tokens = selected_tokens[:, : n_matches + 1]
 
         # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
         # by the model after the last candidate match is also valid, as it is generated from a correct sequence.

@@ -4664,7 +4663,6 @@ class GenerationMixin:
         # is no match.
 
         # 4.1. Get the valid continuation, after the matching tokens
-        valid_tokens = selected_tokens[:, : n_matches + 1]
         input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
         if streamer is not None:
             streamer.put(valid_tokens.cpu())

@@ -4782,24 +4780,16 @@ def _speculative_sampling(
 ):
     """
     Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
-    the next selected token, as well as the number of candidate matches.
+    the selected tokens, as well as the number of candidate matches.
 
     NOTE: Unless otherwise stated, the variable names match those in the paper.
     """
     # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
    # selected by the assistant, respectively.
     q = candidate_logits.softmax(dim=-1)
-    q_i = q[
-        :, torch.range(0, candidate_length - 1, dtype=torch.int), candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    q_i = q[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     p = new_logits.softmax(dim=-1)
-    p_i = p[
-        :, torch.range(0, candidate_length - 1, dtype=torch.int), candidate_input_ids[:, -candidate_length:],
-    ].squeeze(0, 1)
+    p_i = p[:, torch.arange(candidate_length), candidate_input_ids[:, -candidate_length:]].squeeze(0, 1)
     probability_ratio = p_i / q_i
 
     # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller

@@ -4824,7 +4814,13 @@ def _speculative_sampling(
         p_prime = p_n_plus_1
     t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]
-    return t, n_matches
+
+    # The selected tokens include the matches (if any) plus the next sampled tokens
+    if n_matches > 0:
+        valid_tokens = torch.cat((candidate_input_ids[:, -n_matches:], t), dim=-1)
+    else:
+        valid_tokens = t
+
+    return valid_tokens, n_matches
 
 def _split_model_outputs(outputs, new_outputs, cur_len, added_len, is_decoder_attention=False):
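As background for the `_speculative_sampling` changes above, the standalone sketch below illustrates the acceptance rule from algorithm 1 of the speculative decoding paper for a single sequence: candidate token i is kept with probability min(1, p_i / q_i), and at the first rejection a replacement token is drawn from the normalized residual max(0, p - q). This is a simplified illustration under those assumptions, not the library function itself.

```python
import torch


def toy_speculative_sampling(candidate_tokens, q_logits, p_logits):
    """Batch-of-one illustration of algorithm 1 (https://arxiv.org/pdf/2211.17192.pdf).

    candidate_tokens: (candidate_length,) tokens drafted by the assistant model
    q_logits, p_logits: (candidate_length, vocab_size) assistant / main model logits at those positions
    """
    q = q_logits.softmax(dim=-1)  # assistant probabilities
    p = p_logits.softmax(dim=-1)  # main model probabilities
    positions = torch.arange(candidate_tokens.shape[0])
    q_i = q[positions, candidate_tokens]
    p_i = p[positions, candidate_tokens]

    # Keep each candidate with probability min(1, p_i / q_i); stop at the first rejection.
    accept = torch.rand_like(p_i) < torch.clamp(p_i / q_i, max=1.0)
    rejected = (~accept).nonzero()
    n_matches = candidate_tokens.shape[0] if rejected.numel() == 0 else int(rejected[0])

    if n_matches < candidate_tokens.shape[0]:
        # First rejection: resample from the normalized residual distribution max(0, p - q).
        residual = torch.clamp(p[n_matches] - q[n_matches], min=0)
        next_token = torch.multinomial(residual / residual.sum(), num_samples=1)
    else:
        # All candidates accepted: the caller samples the bonus token from the main model's
        # distribution at the following position (not shown here).
        next_token = None
    return candidate_tokens[:n_matches], next_token, n_matches
```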
src/transformers/models/whisper/modeling_whisper.py  (view file @ 45b70384)

@@ -2037,15 +2037,15 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
                 FutureWarning,
             )
 
+        if generation_config is None:
+            generation_config = copy.deepcopy(self.generation_config)
+
         return_dict_in_generate = (
             return_dict_in_generate
             if return_dict_in_generate is not None
-            else self.generation_config.return_dict_in_generate
+            else generation_config.return_dict_in_generate
         )
 
-        if generation_config is None:
-            generation_config = copy.deepcopy(self.generation_config)
-
         input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         if num_segment_frames is None:
             num_segment_frames = input_stride * self.config.max_source_positions
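The Whisper change above only reorders the fallback so that `generation_config` is resolved (deep-copied from the model default) before any of its attributes are read. A minimal sketch of the resulting pattern, using a hypothetical helper rather than the actual method:

```python
import copy


def resolve_generation_options(model_generation_config, generation_config=None, return_dict_in_generate=None):
    # Hypothetical helper mirroring the reordered Whisper logic: resolve the config first,
    # then read per-call defaults from the resolved config rather than the model default.
    if generation_config is None:
        generation_config = copy.deepcopy(model_generation_config)
    if return_dict_in_generate is None:
        return_dict_in_generate = generation_config.return_dict_in_generate
    return generation_config, return_dict_in_generate
```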
tests/models/mistral/test_modeling_mistral.py  (view file @ 45b70384)

@@ -21,7 +21,7 @@ import unittest
 import pytest
 
-from transformers import AutoTokenizer, MistralConfig, is_torch_available
+from transformers import AutoTokenizer, MistralConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     backend_empty_cache,
     require_bitsandbytes,

@@ -527,3 +527,27 @@ class MistralIntegrationTest(unittest.TestCase):
         del model
         backend_empty_cache(torch_device)
         gc.collect()
+
+    @slow
+    def test_speculative_generation(self):
+        EXPECTED_TEXT_COMPLETION = (
+            "My favourite condiment is 100% Sriracha. I love the heat, the tang and the fact costs"
+        )
+        prompt = "My favourite condiment is "
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
+        model = MistralForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
+        )
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)
+
+        # greedy generation outputs
+        set_seed(0)
+        generated_ids = model.generate(
+            input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
+        )
+        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+        self.assertEqual(EXPECTED_TEXT_COMPLETION, text)
+
+        del model
+        backend_empty_cache(torch_device)
+        gc.collect()
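The new integration test is marked `@slow`, so it only runs when `RUN_SLOW=1` is set. It exercises assisted generation with sampling using the model as its own assistant; a rough standalone reproduction of the same call pattern, assuming enough GPU memory for the fp16 checkpoint and without asserting the exact continuation, might look like this:

```python
import torch
from transformers import AutoTokenizer, MistralForCausalLM, set_seed

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=False)
model = MistralForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", device_map="auto", torch_dtype=torch.float16
)

input_ids = tokenizer.encode("My favourite condiment is ", return_tensors="pt").to(
    model.model.embed_tokens.weight.device
)

# Seeded sampling with the model acting as its own assistant (self-speculation).
set_seed(0)
generated_ids = model.generate(
    input_ids, max_new_tokens=20, do_sample=True, temperature=0.3, assistant_model=model
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```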