Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d57ffb48
Unverified
Commit
d57ffb48
authored
May 01, 2024
by
Joao Gante
Committed by
GitHub
May 01, 2024
Browse files
Generate: remove deprecated public decoding functions and streamline logic 🧼 (#29956)
parent
dc401d3a
Changes
6
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
237 additions
and
1295 deletions
+237
-1295
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+1
-1
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+10
-19
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+193
-1215
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+10
-26
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
...ormers/models/musicgen_melody/modeling_musicgen_melody.py
+12
-28
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/modeling_rag.py
+11
-6
No files found.
src/transformers/generation/candidate_generator.py
View file @
d57ffb48
...
@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
inputs_tensor
,
assistant_model
.
generation_config
.
bos_token_id
,
assistant_kwargs
inputs_tensor
,
assistant_model
.
generation_config
.
bos_token_id
,
assistant_kwargs
)
)
assistant_kwargs
=
assistant_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
assistant_kwargs
=
assistant_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
inputs_tensor
,
assistant_kwargs
,
model_input_name
inputs_tensor
,
assistant_kwargs
,
model_input_name
,
assistant_model
.
generation_config
)
)
elif
"encoder_outputs"
in
model_kwargs
:
elif
"encoder_outputs"
in
model_kwargs
:
assistant_kwargs
[
"encoder_outputs"
]
=
model_kwargs
[
"encoder_outputs"
]
assistant_kwargs
[
"encoder_outputs"
]
=
model_kwargs
[
"encoder_outputs"
]
...
...
src/transformers/generation/configuration_utils.py
View file @
d57ffb48
...
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
...
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
- *greedy decoding* if `num_beams=1` and `do_sample=False`
`do_sample=False`
- *contrastive search* if `penalty_alpha>0.` and `top_k>1`
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
and `top_k>1`
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
`do_sample=True`
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
`do_sample=False`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
`num_beams>1` and `do_sample=True`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
<Tip>
<Tip>
...
...
src/transformers/generation/utils.py
View file @
d57ffb48
This diff is collapsed.
Click to expand it.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
d57ffb48
...
@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
self
,
self
,
inputs_tensor
:
torch
.
Tensor
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
model_input_name
:
Optional
[
str
],
g
uidance_scale
:
Optional
[
float
]
=
None
,
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
# 1. get text encoder
# 1. get text encoder
encoder
=
self
.
get_text_encoder
()
encoder
=
self
.
get_text_encoder
()
...
@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
guidance_scale
=
generation_config
.
guidance_scale
# 3. make sure that encoder returns `ModelOutput`
# 3. make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
if
"encoder_outputs"
not
in
model_kwargs
:
if
"encoder_outputs"
not
in
model_kwargs
:
# encoder_outputs are created and added to `model_kwargs`
# encoder_outputs are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
inputs_tensor
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
)
)
if
"decoder_input_ids"
not
in
model_kwargs
and
"input_values"
in
model_kwargs
:
if
"decoder_input_ids"
not
in
model_kwargs
and
"input_values"
in
model_kwargs
:
...
@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
...
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
View file @
d57ffb48
...
@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
self
,
self
,
inputs_tensor
:
torch
.
Tensor
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
model_input_name
:
Optional
[
str
],
g
uidance_scale
:
Optional
[
float
]
=
None
,
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
encoder_hidden_states
=
None
encoder_hidden_states
=
None
# attention mask is consumed once to produce text conditional hidden states through the text encoder
# attention mask is consumed once to produce text conditional hidden states through the text encoder
encoder_attention_mask
=
model_kwargs
.
pop
(
"attention_mask"
)
encoder_attention_mask
=
model_kwargs
.
pop
(
"attention_mask"
)
guidance_scale
=
generation_config
.
guidance_scale
# 1. condition on text
# 1. condition on text
if
inputs_tensor
is
not
None
:
if
inputs_tensor
is
not
None
:
...
@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
# make sure that encoder returns `ModelOutput`
# make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
if
"encoder_hidden_states"
not
in
model_kwargs
:
if
"encoder_hidden_states"
not
in
model_kwargs
:
# encoder_hidden_states are created and added to `model_kwargs`
# encoder_hidden_states are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_encoder_hidden_states_kwargs_for_generation
(
model_kwargs
=
self
.
_prepare_encoder_hidden_states_kwargs_for_generation
(
inputs_tensor
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
)
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
# 5. Prepare `input_ids` which will be used for auto-regressive generation
...
@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
)
# 11. run greedy search
# 11. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
_
greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
)
# 12. run sample
# 12. run sample
outputs
=
self
.
sample
(
outputs
=
self
.
_
sample
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
...
src/transformers/models/rag/modeling_rag.py
View file @
d57ffb48
...
@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
)
)
prepared_stopping_criteria
=
self
.
_get_stopping_criteria
(
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
if
generation_config
.
num_beams
==
1
:
if
generation_config
.
num_beams
==
1
:
if
generation_config
.
num_return_sequences
>
1
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
return
self
.
_greedy_search
(
return
self
.
_greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
pre_processor
,
logits_processor
=
pre_processor
,
max_length
=
generation_config
.
max_length
,
stopping_criteria
=
prepared_stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
synced_gpus
=
False
,
streamer
=
None
,
**
model_kwargs
,
**
model_kwargs
,
)
)
elif
generation_config
.
num_beams
>
1
:
elif
generation_config
.
num_beams
>
1
:
...
@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
input_ids
,
input_ids
,
beam_scorer
,
beam_scorer
,
logits_processor
=
pre_processor
,
logits_processor
=
pre_processor
,
max_length
=
generation_config
.
max_length
,
stopping_criteria
=
prepared_stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
synced_gpus
=
False
,
**
model_kwargs
,
**
model_kwargs
,
)
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment