Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d57ffb48
Unverified
Commit
d57ffb48
authored
May 01, 2024
by
Joao Gante
Committed by
GitHub
May 01, 2024
Browse files
Generate: remove deprecated public decoding functions and streamline logic 🧼 (#29956)
parent
dc401d3a
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
237 additions
and
1295 deletions
+237
-1295
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+1
-1
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+10
-19
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+193
-1215
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+10
-26
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
...ormers/models/musicgen_melody/modeling_musicgen_melody.py
+12
-28
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/modeling_rag.py
+11
-6
No files found.
src/transformers/generation/candidate_generator.py
View file @
d57ffb48
...
...
@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
inputs_tensor
,
assistant_model
.
generation_config
.
bos_token_id
,
assistant_kwargs
)
assistant_kwargs
=
assistant_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
inputs_tensor
,
assistant_kwargs
,
model_input_name
inputs_tensor
,
assistant_kwargs
,
model_input_name
,
assistant_model
.
generation_config
)
elif
"encoder_outputs"
in
model_kwargs
:
assistant_kwargs
[
"encoder_outputs"
]
=
model_kwargs
[
"encoder_outputs"
]
...
...
src/transformers/generation/configuration_utils.py
View file @
d57ffb48
...
...
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
`do_sample=False`
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
`do_sample=True`
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
`do_sample=False`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
`num_beams>1` and `do_sample=True`
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
- *greedy decoding* if `num_beams=1` and `do_sample=False`
- *contrastive search* if `penalty_alpha>0.` and `top_k>1`
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
<Tip>
...
...
src/transformers/generation/utils.py
View file @
d57ffb48
This diff is collapsed.
Click to expand it.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
d57ffb48
...
...
@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
...
@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
input_ids
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
self
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
g
uidance_scale
:
Optional
[
float
]
=
None
,
model_input_name
:
Optional
[
str
],
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
# 1. get text encoder
encoder
=
self
.
get_text_encoder
()
...
...
@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
guidance_scale
=
generation_config
.
guidance_scale
# 3. make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
...
@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
...
@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
if
"encoder_outputs"
not
in
model_kwargs
:
# encoder_outputs are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
inputs_tensor
,
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
)
if
"decoder_input_ids"
not
in
model_kwargs
and
"input_values"
in
model_kwargs
:
...
...
@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
input_ids
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
View file @
d57ffb48
...
...
@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
...
@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
input_ids
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
self
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
g
uidance_scale
:
Optional
[
float
]
=
None
,
model_input_name
:
Optional
[
str
],
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
encoder_hidden_states
=
None
# attention mask is consumed once to produce text conditional hidden states through the text encoder
encoder_attention_mask
=
model_kwargs
.
pop
(
"attention_mask"
)
guidance_scale
=
generation_config
.
guidance_scale
# 1. condition on text
if
inputs_tensor
is
not
None
:
...
...
@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
# make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
...
@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
...
@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
if
"encoder_hidden_states"
not
in
model_kwargs
:
# encoder_hidden_states are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_encoder_hidden_states_kwargs_for_generation
(
inputs_tensor
,
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
...
...
@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
# 11. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
_
greedy_search
(
input_ids
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
# 12. run sample
outputs
=
self
.
sample
(
outputs
=
self
.
_
sample
(
input_ids
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
generation_config
=
generation_config
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
**
model_kwargs
,
...
...
src/transformers/models/rag/modeling_rag.py
View file @
d57ffb48
...
...
@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
logits_processor
=
logits_processor
,
)
prepared_stopping_criteria
=
self
.
_get_stopping_criteria
(
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
if
generation_config
.
num_beams
==
1
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
...
...
@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
return
self
.
_greedy_search
(
input_ids
,
logits_processor
=
pre_processor
,
max_length
=
generation_config
.
max_length
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
stopping_criteria
=
prepared_stopping_criteria
,
generation_config
=
generation_config
,
synced_gpus
=
False
,
streamer
=
None
,
**
model_kwargs
,
)
elif
generation_config
.
num_beams
>
1
:
...
...
@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
input_ids
,
beam_scorer
,
logits_processor
=
pre_processor
,
max_length
=
generation_config
.
max_length
,
pad_token_id
=
generation_config
.
pad_token_id
,
eos_token_id
=
generation_config
.
eos_token_id
,
stopping_criteria
=
prepared_stopping_criteria
,
generation_config
=
generation_config
,
synced_gpus
=
False
,
**
model_kwargs
,
)
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment