Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d57ffb48
Unverified
Commit
d57ffb48
authored
May 01, 2024
by
Joao Gante
Committed by
GitHub
May 01, 2024
Browse files
Generate: remove deprecated public decoding functions and streamline logic 🧼 (#29956)
parent
dc401d3a
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
237 additions
and
1295 deletions
+237
-1295
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+1
-1
src/transformers/generation/configuration_utils.py
src/transformers/generation/configuration_utils.py
+10
-19
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+193
-1215
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+10
-26
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
...ormers/models/musicgen_melody/modeling_musicgen_melody.py
+12
-28
src/transformers/models/rag/modeling_rag.py
src/transformers/models/rag/modeling_rag.py
+11
-6
No files found.
src/transformers/generation/candidate_generator.py
View file @
d57ffb48
...
@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
...
@@ -123,7 +123,7 @@ class AssistedCandidateGenerator(CandidateGenerator):
inputs_tensor
,
assistant_model
.
generation_config
.
bos_token_id
,
assistant_kwargs
inputs_tensor
,
assistant_model
.
generation_config
.
bos_token_id
,
assistant_kwargs
)
)
assistant_kwargs
=
assistant_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
assistant_kwargs
=
assistant_model
.
_prepare_encoder_decoder_kwargs_for_generation
(
inputs_tensor
,
assistant_kwargs
,
model_input_name
inputs_tensor
,
assistant_kwargs
,
model_input_name
,
assistant_model
.
generation_config
)
)
elif
"encoder_outputs"
in
model_kwargs
:
elif
"encoder_outputs"
in
model_kwargs
:
assistant_kwargs
[
"encoder_outputs"
]
=
model_kwargs
[
"encoder_outputs"
]
assistant_kwargs
[
"encoder_outputs"
]
=
model_kwargs
[
"encoder_outputs"
]
...
...
src/transformers/generation/configuration_utils.py
View file @
d57ffb48
...
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
...
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
- *greedy decoding* if `num_beams=1` and `do_sample=False`
`do_sample=False`
- *contrastive search* if `penalty_alpha>0.` and `top_k>1`
- *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
- *multinomial sampling* if `num_beams=1` and `do_sample=True`
and `top_k>1`
- *beam-search decoding* if `num_beams>1` and `do_sample=False`
- *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
- *beam-search multinomial sampling* if `num_beams>1` and `do_sample=True`
`do_sample=True`
- *diverse beam-search decoding* if `num_beams>1` and `num_beam_groups>1`
- *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
- *constrained beam-search decoding* if `constraints!=None` or `force_words_ids!=None`
`do_sample=False`
- *assisted decoding* if `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
- *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
`num_beams>1` and `do_sample=True`
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
- *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
`num_beams>1` and `num_beam_groups>1`
- *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
`constraints!=None` or `force_words_ids!=None`
- *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
`assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
<Tip>
<Tip>
...
...
src/transformers/generation/utils.py
View file @
d57ffb48
This diff is collapsed.
Click to expand it.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
d57ffb48
...
@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1650,8 +1650,6 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1748,10 +1746,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1774,10 +1769,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2423,8 +2415,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
self
,
self
,
inputs_tensor
:
torch
.
Tensor
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
model_input_name
:
Optional
[
str
],
g
uidance_scale
:
Optional
[
float
]
=
None
,
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
# 1. get text encoder
# 1. get text encoder
encoder
=
self
.
get_text_encoder
()
encoder
=
self
.
get_text_encoder
()
...
@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2446,6 +2438,9 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
guidance_scale
=
generation_config
.
guidance_scale
# 3. make sure that encoder returns `ModelOutput`
# 3. make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2708,8 +2703,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2723,10 +2716,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
if
"encoder_outputs"
not
in
model_kwargs
:
if
"encoder_outputs"
not
in
model_kwargs
:
# encoder_outputs are created and added to `model_kwargs`
# encoder_outputs are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
model_kwargs
=
self
.
_prepare_text_encoder_kwargs_for_generation
(
inputs_tensor
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
)
)
if
"decoder_input_ids"
not
in
model_kwargs
and
"input_values"
in
model_kwargs
:
if
"decoder_input_ids"
not
in
model_kwargs
and
"input_values"
in
model_kwargs
:
...
@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2831,10 +2821,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2858,10 +2845,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
...
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
View file @
d57ffb48
...
@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1586,8 +1586,6 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1684,10 +1682,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1710,10 +1705,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2318,12 +2310,13 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
self
,
self
,
inputs_tensor
:
torch
.
Tensor
,
inputs_tensor
:
torch
.
Tensor
,
model_kwargs
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
,
model_input_name
:
Optional
[
str
],
g
uidance_scale
:
Optional
[
float
]
=
None
,
g
eneration_config
:
GenerationConfig
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
encoder_hidden_states
=
None
encoder_hidden_states
=
None
# attention mask is consumed once to produce text conditional hidden states through the text encoder
# attention mask is consumed once to produce text conditional hidden states through the text encoder
encoder_attention_mask
=
model_kwargs
.
pop
(
"attention_mask"
)
encoder_attention_mask
=
model_kwargs
.
pop
(
"attention_mask"
)
guidance_scale
=
generation_config
.
guidance_scale
# 1. condition on text
# 1. condition on text
if
inputs_tensor
is
not
None
:
if
inputs_tensor
is
not
None
:
...
@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2346,6 +2339,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
encoder_kwargs
=
{
encoder_kwargs
=
{
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
argument
:
value
for
argument
,
value
in
encoder_kwargs
.
items
()
if
argument
in
encoder_signature
}
}
encoder_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
encoder_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
# make sure that encoder returns `ModelOutput`
# make sure that encoder returns `ModelOutput`
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
model_input_name
=
model_input_name
if
model_input_name
is
not
None
else
self
.
text_encoder
.
main_input_name
...
@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2572,8 +2567,6 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"output_attentions"
]
=
generation_config
.
output_attentions
model_kwargs
[
"output_hidden_states"
]
=
generation_config
.
output_hidden_states
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
model_kwargs
[
"guidance_scale"
]
=
generation_config
.
guidance_scale
...
@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2585,10 +2578,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
if
"encoder_hidden_states"
not
in
model_kwargs
:
if
"encoder_hidden_states"
not
in
model_kwargs
:
# encoder_hidden_states are created and added to `model_kwargs`
# encoder_hidden_states are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_encoder_hidden_states_kwargs_for_generation
(
model_kwargs
=
self
.
_prepare_encoder_hidden_states_kwargs_for_generation
(
inputs_tensor
,
inputs_tensor
,
model_kwargs
,
model_input_name
,
generation_config
model_kwargs
,
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
)
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
# 5. Prepare `input_ids` which will be used for auto-regressive generation
...
@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2684,14 +2674,11 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
)
# 11. run greedy search
# 11. run greedy search
outputs
=
self
.
greedy_search
(
outputs
=
self
.
_
greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2710,15 +2697,12 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
)
)
# 12. run sample
# 12. run sample
outputs
=
self
.
sample
(
outputs
=
self
.
_
sample
(
input_ids
,
input_ids
,
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
logits_warper
=
logits_warper
,
logits_warper
=
logits_warper
,
stopping_criteria
=
stopping_criteria
,
stopping_criteria
=
stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
output_scores
=
generation_config
.
output_scores
,
return_dict_in_generate
=
generation_config
.
return_dict_in_generate
,
synced_gpus
=
synced_gpus
,
synced_gpus
=
synced_gpus
,
streamer
=
streamer
,
streamer
=
streamer
,
**
model_kwargs
,
**
model_kwargs
,
...
...
src/transformers/models/rag/modeling_rag.py
View file @
d57ffb48
...
@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1537,6 +1537,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
logits_processor
=
logits_processor
,
logits_processor
=
logits_processor
,
)
)
prepared_stopping_criteria
=
self
.
_get_stopping_criteria
(
generation_config
=
generation_config
,
stopping_criteria
=
stopping_criteria
)
if
generation_config
.
num_beams
==
1
:
if
generation_config
.
num_beams
==
1
:
if
generation_config
.
num_return_sequences
>
1
:
if
generation_config
.
num_return_sequences
>
1
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1546,9 +1550,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
return
self
.
_greedy_search
(
return
self
.
_greedy_search
(
input_ids
,
input_ids
,
logits_processor
=
pre_processor
,
logits_processor
=
pre_processor
,
max_length
=
generation_config
.
max_length
,
stopping_criteria
=
prepared_stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
synced_gpus
=
False
,
streamer
=
None
,
**
model_kwargs
,
**
model_kwargs
,
)
)
elif
generation_config
.
num_beams
>
1
:
elif
generation_config
.
num_beams
>
1
:
...
@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
...
@@ -1567,9 +1572,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
input_ids
,
input_ids
,
beam_scorer
,
beam_scorer
,
logits_processor
=
pre_processor
,
logits_processor
=
pre_processor
,
max_length
=
generation_config
.
max_length
,
stopping_criteria
=
prepared_stopping_criteria
,
pad_token_id
=
generation_config
.
pad_token_id
,
generation_config
=
generation_config
,
eos_token_id
=
generation_config
.
eos_token_id
,
synced_gpus
=
False
,
**
model_kwargs
,
**
model_kwargs
,
)
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment