chenpangpang / transformers · commit 87a0783d
Authored Mar 05, 2024 by Joao Gante, committed via GitHub Mar 05, 2024 (unverified signature).
Generate: inner decoding methods are no longer public (#29437)
Parent: 4d892b72
Showing 11 changed files with 117 additions and 104 deletions (+117 −104).
docs/source/en/generation_strategies.md               +3  −0
docs/source/en/internal/generation_utils.md           +2  −11
docs/source/en/main_classes/text_generation.md        +0  −7
docs/source/ja/internal/generation_utils.md           +0  −9
docs/source/ja/main_classes/text_generation.md        +0  −7
docs/source/zh/internal/generation_utils.md           +1  −10
docs/source/zh/main_classes/text_generation.md        +0  −7
src/transformers/generation/configuration_utils.py    +9  −9
src/transformers/generation/utils.py                  +96 −38
src/transformers/models/musicgen/modeling_musicgen.py +4  −4
src/transformers/models/rag/modeling_rag.py           +2  −2
docs/source/en/generation_strategies.md:

````diff
@@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Alice and Bob are going to the same party. It is a small party, in a small']
 ```
+
+Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
+to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
````
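The n-gram based assisted decoding added above (prompt lookup) reuses the prompt itself as the draft model. The idea can be sketched as a toy reimplementation over plain token-ID lists; `propose_candidates` and its defaults are illustrative, not the transformers API:

```python
def propose_candidates(input_ids, ngram_size=2, num_pred_tokens=3):
    """Prompt lookup decoding: find an earlier occurrence of the prompt's
    final n-gram and propose the tokens that followed it as draft tokens."""
    ngram = input_ids[-ngram_size:]
    # Scan right-to-left (excluding the trailing n-gram itself) so the most
    # recent earlier match wins.
    for start in range(len(input_ids) - ngram_size - 1, -1, -1):
        if input_ids[start:start + ngram_size] == ngram:
            follow = input_ids[start + ngram_size:start + ngram_size + num_pred_tokens]
            if follow:
                return follow
    return []  # no repeated n-gram: fall back to normal decoding

print(propose_candidates([5, 8, 2, 5, 8, 9, 1, 5, 8]))  # → [9, 1, 5]
```

In transformers this mode is enabled by passing `prompt_lookup_num_tokens` to `generate`; the drafted tokens are then verified by the main model just like assistant-model drafts.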
docs/source/en/internal/generation_utils.md:

```diff
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
 # Utilities for Generation
 
-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`], [`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`], [`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`], [`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-
-Most of those are only useful if you are studying the code of the generate methods in the library.
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
 
 ## Generate Outputs
```
```diff
@@ -376,4 +367,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 [[autodoc]] StaticCache
     - update
-    - get_seq_length
\ No newline at end of file
+    - get_seq_length
```
docs/source/en/main_classes/text_generation.md:

```diff
@@ -43,13 +43,6 @@ like token streaming.
 [[autodoc]] generation.GenerationMixin
 	- generate
 	- compute_transition_scores
-	- greedy_search
-	- sample
-	- beam_search
-	- beam_sample
-	- contrastive_search
-	- group_beam_search
-	- constrained_beam_search
 
 ## TFGenerationMixin
```
docs/source/ja/internal/generation_utils.md (Japanese prose rendered in English):

```diff
@@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.
 # Utilities for Generation
 
 This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-
-Most of these are only useful if you are studying the code of the generation methods in the library.
 
 ## Generate Outputs
```
docs/source/ja/main_classes/text_generation.md:

```diff
@@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.
 [[autodoc]] generation.GenerationMixin
 	- generate
 	- compute_transition_scores
-	- greedy_search
-	- sample
-	- beam_search
-	- beam_sample
-	- contrastive_search
-	- group_beam_search
-	- constrained_beam_search
 
 ## TFGenerationMixin
```
docs/source/zh/internal/generation_utils.md (Chinese prose rendered in English):

```diff
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
 # Utilities for Generation
 
-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`], [`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`], [`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`], [`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
-
-Most of those are only useful if you are studying the code of the generation methods in the library.
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
 
 ## Generate Outputs
```
docs/source/zh/main_classes/text_generation.md:

```diff
@@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.
 [[autodoc]] generation.GenerationMixin
 	- generate
 	- compute_transition_scores
-	- greedy_search
-	- sample
-	- beam_search
-	- beam_sample
-	- contrastive_search
-	- group_beam_search
-	- constrained_beam_search
 
 ## TFGenerationMixin
```
src/transformers/generation/configuration_utils.py:

```diff
@@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
     Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
     for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
 
-        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
           `do_sample=False`
-        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
           and `top_k>1`
-        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
           `do_sample=True`
-        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
           `do_sample=False`
-        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
           `num_beams>1` and `do_sample=True`
-        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
+        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
           `num_beams>1` and `num_beam_groups>1`
-        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
           `constraints!=None` or `force_words_ids!=None`
-        - *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
-          `assistant_model` is passed to `.generate()`
+        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+          `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
 
     You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
     more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
```
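The docstring above maps config flags to decoding methods; that dispatch can be summarized in a minimal sketch, assuming a simplified subset of `GenerationConfig` fields (the real selection logic in `generate` handles more options and validation):

```python
def get_generation_mode(num_beams=1, do_sample=False, penalty_alpha=None,
                        top_k=50, num_beam_groups=1, constraints=None,
                        assistant_model=None):
    """Illustrative flag-to-method dispatch, mirroring the docstring bullets."""
    # Special modes first, then the beam/sampling matrix.
    if assistant_model is not None:
        return "assisted_decoding"
    if constraints is not None:
        return "constrained_beam_search"
    if num_beams == 1:
        if penalty_alpha is not None and penalty_alpha > 0 and top_k > 1:
            return "contrastive_search"
        return "sample" if do_sample else "greedy_search"
    if num_beam_groups > 1:
        return "group_beam_search"
    return "beam_sample" if do_sample else "beam_search"

print(get_generation_mode())                           # → greedy_search
print(get_generation_mode(penalty_alpha=0.6))          # → contrastive_search
print(get_generation_mode(num_beams=4, do_sample=True))  # → beam_sample
```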
src/transformers/generation/utils.py:

```diff
@@ -347,20 +347,22 @@ class GenerationMixin:
     A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
 
     The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
 
-        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
           `do_sample=False`
-        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
+        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
           `top_k>1`
-        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
           `do_sample=True`
-        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
           `do_sample=False`
-        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
+        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
           and `do_sample=True`
-        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
+        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
           and `num_beam_groups>1`
-        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
           `constraints!=None` or `force_words_ids!=None`
         - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
           `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
 
     You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
     learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
```
```diff
@@ -1547,7 +1549,7 @@ class GenerationMixin:
         )
 
         if generation_mode == GenerationMode.GREEDY_SEARCH:
             # 11. run greedy search
-            result = self.greedy_search(
+            result = self._greedy_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
```
```diff
@@ -1565,7 +1567,7 @@ class GenerationMixin:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
 
-            result = self.contrastive_search(
+            result = self._contrastive_search(
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
```
```diff
@@ -1595,7 +1597,7 @@ class GenerationMixin:
             )
 
             # 13. run sample
-            result = self.sample(
+            result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
```
```diff
@@ -1629,7 +1631,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.beam_search(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
```
```diff
@@ -1668,7 +1670,7 @@ class GenerationMixin:
             )
             # 14. run beam sample
-            result = self.beam_sample(
+            result = self._beam_sample(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
```
```diff
@@ -1703,7 +1705,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.group_beam_search(
+            result = self._group_beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
```
```diff
@@ -1777,7 +1779,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.constrained_beam_search(
+            result = self._constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
                 logits_processor=prepared_logits_processor,
```
```diff
@@ -1801,8 +1803,15 @@ class GenerationMixin:
 
         return result
 
+    def contrastive_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._contrastive_search(*args, **kwargs)
+
     @torch.no_grad()
-    def contrastive_search(
+    def _contrastive_search(
         self,
         input_ids: torch.LongTensor,
         top_k: Optional[int] = 1,
```
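The hunk above introduces the pattern this commit repeats for every inner method: the old public name becomes a thin shim that warns once and forwards to the new underscore-prefixed implementation. A minimal sketch of that pattern, using the stdlib `warnings` module in place of transformers' `logger.warning_once`:

```python
import warnings

class GenerationMixinSketch:
    """Sketch of the deprecation-shim pattern: public name warns and
    forwards to the private implementation."""

    def contrastive_search(self, *args, **kwargs):
        warnings.warn(
            "Calling `contrastive_search` directly is deprecated. "
            "Use `generate` or a custom generation loop instead.",
            FutureWarning,
        )
        return self._contrastive_search(*args, **kwargs)

    def _contrastive_search(self, *args, **kwargs):
        return "decoded"  # stand-in for the real decoding loop

model = GenerationMixinSketch()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = model.contrastive_search()
print(result, len(caught))  # → decoded 1
```

Keeping the old name callable gives downstream code a release cycle to migrate before the public entry point is removed in v4.41.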
```diff
@@ -1828,7 +1837,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -1902,7 +1911,7 @@ class GenerationMixin:
         >>> input_prompt = "DeepMind Company is"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt")
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
-        >>> outputs = model.contrastive_search(
+        >>> outputs = model._contrastive_search(
         ...     **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
         ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
```
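For reference, the contrastive-search selection rule being made private here balances model confidence against a degeneration penalty. A toy sketch on plain Python lists; the `(token, prob, hidden_state)` triple format and `contrastive_pick` are illustrative, not the library API:

```python
def cosine(u, v):
    """Cosine similarity between two plain-list vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = sum(a * a for a in u) ** 0.5
    norm_v = sum(b * b for b in v) ** 0.5
    return dot / (norm_u * norm_v)

def contrastive_pick(candidates, prev_states, penalty_alpha=0.6):
    """Pick the candidate maximizing (1 - alpha) * p(token) minus alpha times
    the max cosine similarity to any previous hidden state (degeneration
    penalty), so repetitive continuations are discouraged."""
    def score(cand):
        _, prob, hidden = cand
        penalty = max(cosine(hidden, prev) for prev in prev_states)
        return (1 - penalty_alpha) * prob - penalty_alpha * penalty
    return max(candidates, key=score)[0]

prev = [[1.0, 0.0]]  # hidden state of the text so far (toy 2-d vectors)
candidates = [("repeat", 0.9, [1.0, 0.0]), ("novel", 0.5, [0.0, 1.0])]
print(contrastive_pick(candidates, prev))  # → novel
```

With `penalty_alpha=0` the rule degenerates to plain greedy choice, which is why the docstrings above gate contrastive search on `penalty_alpha>0`.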
```diff
@@ -2243,7 +2252,14 @@ class GenerationMixin:
         else:
             return input_ids
 
-    def greedy_search(
+    def greedy_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._greedy_search(*args, **kwargs)
+
+    def _greedy_search(
         self,
         input_ids: torch.LongTensor,
         logits_processor: Optional[LogitsProcessorList] = None,
```
```diff
@@ -2266,7 +2282,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -2348,7 +2364,7 @@ class GenerationMixin:
         ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 
-        >>> outputs = model.greedy_search(
+        >>> outputs = model._greedy_search(
         ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
         ... )
```
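The `_greedy_search` loop being renamed here reduces, at its core, to "argmax until EOS or budget". A toy sketch with a stand-in scoring function (the names and toy vocabulary are illustrative):

```python
def greedy_decode(next_token_scores, input_ids, max_new_tokens=3, eos_token_id=0):
    """Toy greedy loop: pick the argmax token from a scoring function until
    EOS or the token budget is hit. `next_token_scores` is a hypothetical
    stand-in for a model forward pass plus logits processors."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        scores = next_token_scores(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ids.append(next_id)
        if next_id == eos_token_id:
            break
    return ids

# Stand-in "model" over a 5-token vocabulary: prefers (last token + 1) mod 5.
score = lambda ids: [1.0 if tok == (ids[-1] + 1) % 5 else 0.0 for tok in range(5)]
print(greedy_decode(score, [1]))  # → [1, 2, 3, 4]
```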
```diff
@@ -2514,7 +2530,14 @@ class GenerationMixin:
         else:
             return input_ids
 
-    def sample(
+    def sample(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._sample(*args, **kwargs)
+
+    def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: Optional[LogitsProcessorList] = None,
```
```diff
@@ -2538,7 +2561,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
+        In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
         For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -2635,7 +2658,7 @@ class GenerationMixin:
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 
         >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
-        >>> outputs = model.sample(
+        >>> outputs = model._sample(
         ...     input_ids,
         ...     logits_processor=logits_processor,
         ...     logits_warper=logits_warper,
```
```diff
@@ -2832,7 +2855,14 @@ class GenerationMixin:
                 past_key_values.reorder_cache(beam_idx)
         return past_key_values
 
-    def beam_search(
+    def beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._beam_search(*args, **kwargs)
+
+    def _beam_search(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
```
```diff
@@ -2856,7 +2886,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -2958,7 +2988,7 @@ class GenerationMixin:
         ...     ]
         ... )
 
-        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
+        >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
 
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Wie alt bist du?']
```
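`_beam_search` keeps the `num_beams` highest-scoring partial sequences at each step instead of a single one. A toy sketch over a hypothetical per-step log-probability function (`toy_beam_search` and `step_logprobs` are illustrative, not the library API):

```python
import heapq
import math

def toy_beam_search(step_logprobs, start_token, num_beams=2, steps=3):
    """Toy beam search: expand every beam with every candidate token, then
    prune back to the `num_beams` best (score, sequence) pairs per step.
    `step_logprobs(seq)` returns a {token: logprob} next-token model."""
    beams = [(0.0, [start_token])]
    for _ in range(steps):
        candidates = []
        for score, seq in beams:
            for token, logprob in step_logprobs(seq).items():
                candidates.append((score + logprob, seq + [token]))
        beams = heapq.nlargest(num_beams, candidates)  # prune to best beams
    return beams[0][1]  # highest-scoring full sequence

dist = lambda seq: {0: math.log(0.6), 1: math.log(0.4)}
print(toy_beam_search(dist, 9))  # → [9, 0, 0, 0]
```

Summing log-probabilities instead of multiplying probabilities keeps the scores numerically stable for long sequences.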
```diff
@@ -3214,7 +3244,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
-    def beam_sample(
+    def beam_sample(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._beam_sample(*args, **kwargs)
+
+    def _beam_sample(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
```
```diff
@@ -3238,7 +3275,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -3346,7 +3383,7 @@ class GenerationMixin:
         ...     ]
         ... )
 
-        >>> outputs = model.beam_sample(
+        >>> outputs = model._beam_sample(
         ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
         ... )
```
```diff
@@ -3561,7 +3598,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
-    def group_beam_search(
+    def group_beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._group_beam_search(*args, **kwargs)
+
+    def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
```
```diff
@@ -3584,7 +3628,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -3686,7 +3730,7 @@ class GenerationMixin:
         ...     ]
         ... )
 
-        >>> outputs = model.group_beam_search(
+        >>> outputs = model._group_beam_search(
         ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
         ... )
```
```diff
@@ -3958,7 +4002,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
-    def constrained_beam_search(
+    def constrained_beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._constrained_beam_search(*args, **kwargs)
+
+    def _constrained_beam_search(
         self,
         input_ids: torch.LongTensor,
         constrained_beam_scorer: ConstrainedBeamSearchScorer,
```
```diff
@@ -3981,7 +4032,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -4088,7 +4139,7 @@ class GenerationMixin:
         ...     ]
         ... )
 
-        >>> outputs = model.constrained_beam_search(
+        >>> outputs = model._constrained_beam_search(
         ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
         ... )
```
```diff
@@ -4311,7 +4362,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
-    def assisted_decoding(
+    def assisted_decoding(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `_assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._assisted_decoding(*args, **kwargs)
+
+    def _assisted_decoding(
         self,
         input_ids: torch.LongTensor,
         candidate_generator: Optional["CandidateGenerator"] = None,
```
```diff
@@ -4338,7 +4396,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
```
```diff
@@ -4429,7 +4487,7 @@ class GenerationMixin:
         ...     logits_processor=logits_processor,
         ...     model_kwargs={},
         ... )
-        >>> outputs = model.assisted_decoding(
+        >>> outputs = model._assisted_decoding(
         ...     input_ids,
         ...     candidate_generator=candidate_generator,
         ...     logits_processor=logits_processor,
```
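In `_assisted_decoding`, draft tokens from the candidate generator are verified against the target model and accepted up to the first disagreement. A simplified sketch of that acceptance rule (greedy verification only; the real method also handles sampling and KV-cache rollback, and `accept_candidates` is an illustrative name):

```python
def accept_candidates(draft_ids, target_greedy_ids):
    """Simplified assisted-decoding acceptance: keep the longest prefix of
    the draft that the target model would have produced itself, plus the
    first target token that disagrees (so every iteration yields at least
    one new token)."""
    accepted = []
    for draft, target in zip(draft_ids, target_greedy_ids):
        if draft == target:
            accepted.append(draft)
        else:
            accepted.append(target)  # correction token from the target model
            break
    return accepted

print(accept_candidates([4, 7, 7, 2], [4, 7, 9, 5]))  # → [4, 7, 9]
```

Because accepted tokens match what the target model would have emitted anyway, greedy assisted decoding changes only latency, not the output.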
src/transformers/models/musicgen/modeling_musicgen.py:

```diff
@@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
```
```diff
@@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
```
```diff
@@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
```
```diff
@@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
```
src/transformers/models/rag/modeling_rag.py:

```diff
@@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self.greedy_search(
+            return self._greedy_search(
                 input_ids,
                 logits_processor=pre_processor,
                 max_length=generation_config.max_length,
```
```diff
@@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
-            return self.beam_search(
+            return self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=pre_processor,
```