chenpangpang / transformers

Commit 87a0783d (unverified), authored Mar 05, 2024 by Joao Gante, committed via GitHub on Mar 05, 2024

Generate: inner decoding methods are no longer public (#29437)

parent 4d892b72

Changes: 11 changed files, with 117 additions and 104 deletions (+117 −104)
- docs/source/en/generation_strategies.md (+3 −0)
- docs/source/en/internal/generation_utils.md (+2 −11)
- docs/source/en/main_classes/text_generation.md (+0 −7)
- docs/source/ja/internal/generation_utils.md (+0 −9)
- docs/source/ja/main_classes/text_generation.md (+0 −7)
- docs/source/zh/internal/generation_utils.md (+1 −10)
- docs/source/zh/main_classes/text_generation.md (+0 −7)
- src/transformers/generation/configuration_utils.py (+9 −9)
- src/transformers/generation/utils.py (+96 −38)
- src/transformers/models/musicgen/modeling_musicgen.py (+4 −4)
- src/transformers/models/rag/modeling_rag.py (+2 −2)
docs/source/en/generation_strategies.md

````diff
@@ -389,3 +389,6 @@ just like in multinomial sampling. However, in assisted decoding, reducing the t
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Alice and Bob are going to the same party. It is a small party, in a small']
 ```
+
+Alternatively, you can also set the `prompt_lookup_num_tokens` to trigger n-gram based assisted decoding, as opposed
+to model based assisted decoding. You can read more about it [here](https://twitter.com/joao_gante/status/1747322413006643259).
````
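The n-gram based assisted decoding referenced in the added doc lines (prompt lookup decoding) can be sketched in plain Python. This is a hypothetical, simplified helper for illustration only, not the library's implementation: it finds the most recent earlier occurrence of the trailing n-gram and proposes the tokens that followed it as draft candidates.

```python
def prompt_lookup_candidates(tokens, ngram_size=2, num_pred=5):
    """Sketch of prompt lookup decoding: look for the last `ngram_size`
    tokens earlier in the sequence and propose the tokens that came after
    that match as candidate continuations."""
    ngram = tokens[-ngram_size:]
    # Scan earlier positions, most recent first, for a match of the trailing n-gram.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == ngram:
            follow = tokens[start + ngram_size:start + ngram_size + num_pred]
            if follow:
                return follow
    return []  # no match: fall back to normal decoding


# The trailing bigram "on the" occurred earlier, followed by "mat .":
seq = ["the", "cat", "sat", "on", "the", "mat", ".", "on", "the"]
print(prompt_lookup_candidates(seq, num_pred=2))  # ['mat', '.']
```

The real `prompt_lookup_num_tokens` path works on token-id tensors and then verifies the drafted tokens with the model in a single forward pass, but the candidate-proposal idea is the same.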
docs/source/en/internal/generation_utils.md

```diff
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
 # Utilities for Generation
 
-This page lists all the utility functions used by [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], and
-[`~generation.GenerationMixin.constrained_beam_search`].
+This page lists all the utility functions used by [`~generation.GenerationMixin.generate`].
 
 Most of those are only useful if you are studying the code of the generate methods in the library.
 
 ## Generate Outputs
@@ -376,4 +367,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 [[autodoc]] StaticCache
     - update
     - get_seq_length
\ No newline at end of file
```
docs/source/en/main_classes/text_generation.md

```diff
@@ -43,13 +43,6 @@ like token streaming.
 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search
 
 ## TFGenerationMixin
```
docs/source/ja/internal/generation_utils.md

```diff
@@ -17,15 +17,6 @@ rendered properly in your Markdown viewer.
 # 発電用ユーティリティ
 
 このページには、[`~generation.GenerationMixin.generate`] で使用されるすべてのユーティリティ関数がリストされています。
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`]、および
-[`~generation.GenerationMixin.constrained_beam_search`]。
 
 これらのほとんどは、ライブラリ内の生成メソッドのコードを学習する場合にのみ役に立ちます。
 
 ## 出力を生成する
```
docs/source/ja/main_classes/text_generation.md

```diff
@@ -43,13 +43,6 @@ rendered properly in your Markdown viewer.
 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search
 
 ## TFGenerationMixin
```
docs/source/zh/internal/generation_utils.md

```diff
@@ -16,16 +16,7 @@ rendered properly in your Markdown viewer.
 # 用于生成的工具
 
-此页面列出了所有由 [`~generation.GenerationMixin.generate`],
-[`~generation.GenerationMixin.greedy_search`],
-[`~generation.GenerationMixin.contrastive_search`],
-[`~generation.GenerationMixin.sample`],
-[`~generation.GenerationMixin.beam_search`],
-[`~generation.GenerationMixin.beam_sample`],
-[`~generation.GenerationMixin.group_beam_search`], 和
-[`~generation.GenerationMixin.constrained_beam_search`]使用的实用函数。
+此页面列出了所有由 [`~generation.GenerationMixin.generate`]。
 
 其中大多数仅在您研究库中生成方法的代码时才有用。
 
 ## 生成输出
```
docs/source/zh/main_classes/text_generation.md

```diff
@@ -38,13 +38,6 @@ rendered properly in your Markdown viewer.
 [[autodoc]] generation.GenerationMixin
     - generate
     - compute_transition_scores
-    - greedy_search
-    - sample
-    - beam_search
-    - beam_sample
-    - contrastive_search
-    - group_beam_search
-    - constrained_beam_search
 
 ## TFGenerationMixin
```
src/transformers/generation/configuration_utils.py

```diff
@@ -43,22 +43,22 @@ class GenerationConfig(PushToHubMixin):
     Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
     for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
 
-        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
           `do_sample=False`
-        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0.`
+        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0.`
           and `top_k>1`
-        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
           `do_sample=True`
-        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
           `do_sample=False`
-        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if
+        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if
           `num_beams>1` and `do_sample=True`
-        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if
+        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if
           `num_beams>1` and `num_beam_groups>1`
-        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
           `constraints!=None` or `force_words_ids!=None`
-        - *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
-          `assistant_model` is passed to `.generate()`
+        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+          `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
 
     You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
     more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
```
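The dispatch rules listed in the docstring above can be summarized as a small pure-Python sketch. `select_generation_mode` is a hypothetical helper written only to illustrate the parameter-to-strategy mapping; the library's actual selection logic lives in `GenerationConfig`/`generate` and handles more cases.

```python
def select_generation_mode(num_beams=1, do_sample=False, penalty_alpha=None,
                           top_k=50, num_beam_groups=1, constraints=None,
                           assistant_model=None, prompt_lookup_num_tokens=None):
    """Illustrative mapping from generation parameters to the (now private)
    decoding method named in the docstring above. Not the library's code."""
    if assistant_model is not None or prompt_lookup_num_tokens is not None:
        return "_assisted_decoding"
    if constraints is not None:
        return "_constrained_beam_search"
    if num_beams > 1 and num_beam_groups > 1:
        return "_group_beam_search"
    if num_beams > 1:
        return "_beam_sample" if do_sample else "_beam_search"
    if penalty_alpha is not None and penalty_alpha > 0 and top_k > 1:
        return "_contrastive_search"
    return "_sample" if do_sample else "_greedy_search"


print(select_generation_mode())                            # _greedy_search
print(select_generation_mode(do_sample=True))              # _sample
print(select_generation_mode(num_beams=4))                 # _beam_search
print(select_generation_mode(penalty_alpha=0.6, top_k=4))  # _contrastive_search
```

Note the ordering: assisted decoding and constrained beam search take priority over the flag combinations below them, mirroring the docstring's "if X is passed" phrasing.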
src/transformers/generation/utils.py

```diff
@@ -347,20 +347,22 @@ class GenerationMixin:
     A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
 
     The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
 
-        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
+        - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
          `do_sample=False`
-        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
+        - *contrastive search* by calling [`~generation.GenerationMixin._contrastive_search`] if `penalty_alpha>0` and
          `top_k>1`
-        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
+        - *multinomial sampling* by calling [`~generation.GenerationMixin._sample`] if `num_beams=1` and
          `do_sample=True`
-        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
+        - *beam-search decoding* by calling [`~generation.GenerationMixin._beam_search`] if `num_beams>1` and
          `do_sample=False`
-        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
+        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin._beam_sample`] if `num_beams>1`
          and `do_sample=True`
-        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
+        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin._group_beam_search`], if `num_beams>1`
          and `num_beam_groups>1`
-        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
+        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin._constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`
+        - *assisted decoding* by calling [`~generation.GenerationMixin._assisted_decoding`], if
+          `assistant_model` or `prompt_lookup_num_tokens` is passed to `.generate()`
 
     You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
     learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
@@ -1547,7 +1549,7 @@ class GenerationMixin:
         )
         if generation_mode == GenerationMode.GREEDY_SEARCH:
             # 11. run greedy search
-            result = self.greedy_search(
+            result = self._greedy_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 stopping_criteria=prepared_stopping_criteria,
@@ -1565,7 +1567,7 @@ class GenerationMixin:
             if not model_kwargs["use_cache"]:
                 raise ValueError("Contrastive search requires `use_cache=True`")
 
-            result = self.contrastive_search(
+            result = self._contrastive_search(
                 input_ids,
                 top_k=generation_config.top_k,
                 penalty_alpha=generation_config.penalty_alpha,
@@ -1595,7 +1597,7 @@ class GenerationMixin:
             )
 
             # 13. run sample
-            result = self.sample(
+            result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
                 logits_warper=logits_warper,
@@ -1629,7 +1631,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.beam_search(
+            result = self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1668,7 +1670,7 @@ class GenerationMixin:
             )
 
             # 14. run beam sample
-            result = self.beam_sample(
+            result = self._beam_sample(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1703,7 +1705,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.group_beam_search(
+            result = self._group_beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1777,7 +1779,7 @@ class GenerationMixin:
                 **model_kwargs,
             )
             # 13. run beam search
-            result = self.constrained_beam_search(
+            result = self._constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
                 logits_processor=prepared_logits_processor,
@@ -1801,8 +1803,15 @@ class GenerationMixin:
         return result
 
+    def contrastive_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._contrastive_search(*args, **kwargs)
+
     @torch.no_grad()
-    def contrastive_search(
+    def _contrastive_search(
         self,
         input_ids: torch.LongTensor,
         top_k: Optional[int] = 1,
@@ -1828,7 +1837,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.contrastive_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._contrastive_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -1902,7 +1911,7 @@ class GenerationMixin:
         >>> input_prompt = "DeepMind Company is"
         >>> input_ids = tokenizer(input_prompt, return_tensors="pt")
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
-        >>> outputs = model.contrastive_search(
+        >>> outputs = model._contrastive_search(
        ...     **input_ids, penalty_alpha=0.6, top_k=4, stopping_criteria=stopping_criteria
        ... )
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
@@ -2243,7 +2252,14 @@ class GenerationMixin:
         else:
             return input_ids
 
+    def greedy_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `greedy_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._greedy_search(*args, **kwargs)
+
-    def greedy_search(
+    def _greedy_search(
         self,
         input_ids: torch.LongTensor,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -2266,7 +2282,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._greedy_search`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -2348,7 +2364,7 @@ class GenerationMixin:
        ... )
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 
-        >>> outputs = model.greedy_search(
+        >>> outputs = model._greedy_search(
        ...     input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
        ... )
@@ -2514,7 +2530,14 @@ class GenerationMixin:
         else:
             return input_ids
 
+    def sample(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._sample(*args, **kwargs)
+
-    def sample(
+    def _sample(
         self,
         input_ids: torch.LongTensor,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -2538,7 +2561,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
+        In most cases, you do not need to call [`~generation.GenerationMixin._sample`] directly. Use generate() instead.
         For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -2635,7 +2658,7 @@ class GenerationMixin:
         >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
 
         >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
-        >>> outputs = model.sample(
+        >>> outputs = model._sample(
        ...     input_ids,
        ...     logits_processor=logits_processor,
        ...     logits_warper=logits_warper,
@@ -2832,7 +2855,14 @@ class GenerationMixin:
                 past_key_values.reorder_cache(beam_idx)
         return past_key_values
 
+    def beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._beam_search(*args, **kwargs)
+
-    def beam_search(
+    def _beam_search(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
@@ -2856,7 +2886,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._beam_search`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -2958,7 +2988,7 @@ class GenerationMixin:
        ...     ]
        ... )
 
-        >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
+        >>> outputs = model._beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
 
         >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
         ['Wie alt bist du?']
@@ -3214,7 +3244,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
+    def beam_sample(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `beam_sample` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._beam_sample(*args, **kwargs)
+
-    def beam_sample(
+    def _beam_sample(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
@@ -3238,7 +3275,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.beam_sample`] directly. Use generate()
+        In most cases, you do not need to call [`~generation.GenerationMixin._beam_sample`] directly. Use generate()
         instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -3346,7 +3383,7 @@ class GenerationMixin:
        ...     ]
        ... )
 
-        >>> outputs = model.beam_sample(
+        >>> outputs = model._beam_sample(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, logits_warper=logits_warper, **model_kwargs
        ... )
@@ -3561,7 +3598,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
+    def group_beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `group_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._group_beam_search(*args, **kwargs)
+
-    def group_beam_search(
+    def _group_beam_search(
         self,
         input_ids: torch.LongTensor,
         beam_scorer: BeamScorer,
@@ -3584,7 +3628,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.group_beam_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._group_beam_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -3686,7 +3730,7 @@ class GenerationMixin:
        ...     ]
        ... )
 
-        >>> outputs = model.group_beam_search(
+        >>> outputs = model._group_beam_search(
        ...     input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
        ... )
@@ -3958,7 +4002,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
+    def constrained_beam_search(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `constrained_beam_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._constrained_beam_search(*args, **kwargs)
+
-    def constrained_beam_search(
+    def _constrained_beam_search(
         self,
         input_ids: torch.LongTensor,
         constrained_beam_scorer: ConstrainedBeamSearchScorer,
@@ -3981,7 +4032,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.constrained_beam_search`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._constrained_beam_search`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -4088,7 +4139,7 @@ class GenerationMixin:
        ...     ]
        ... )
 
-        >>> outputs = model.constrained_beam_search(
+        >>> outputs = model._constrained_beam_search(
        ...     input_ids, beam_scorer, constraints=constraints, logits_processor=logits_processor, **model_kwargs
        ... )
@@ -4311,7 +4362,14 @@ class GenerationMixin:
         else:
             return sequence_outputs["sequences"]
 
+    def assisted_decoding(self, *args, **kwargs):
+        logger.warning_once(
+            "Calling `_assisted_decoding` directly is deprecated and will be removed in v4.41. Use `generate` or a "
+            "custom generation loop instead.",
+        )
+        return self._assisted_decoding(*args, **kwargs)
+
-    def assisted_decoding(
+    def _assisted_decoding(
         self,
         input_ids: torch.LongTensor,
         candidate_generator: Optional["CandidateGenerator"] = None,
@@ -4338,7 +4396,7 @@ class GenerationMixin:
         <Tip warning={true}>
 
-        In most cases, you do not need to call [`~generation.GenerationMixin.candidate_decoding`] directly. Use
+        In most cases, you do not need to call [`~generation.GenerationMixin._assisted_decoding`] directly. Use
         generate() instead. For an overview of generation strategies and code examples, check the [following
         guide](../generation_strategies).
@@ -4429,7 +4487,7 @@ class GenerationMixin:
        ...     logits_processor=logits_processor,
        ...     model_kwargs={},
        ... )
-        >>> outputs = model.assisted_decoding(
+        >>> outputs = model._assisted_decoding(
        ...     input_ids,
        ...     candidate_generator=candidate_generator,
        ...     logits_processor=logits_processor,
```
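Every wrapper added in this file follows the same shape: keep the old public name, emit a one-time deprecation warning, and forward to the renamed private method. A generic, self-contained sketch of that pattern (using the stdlib `warnings` module in place of transformers' `logger.warning_once`; the names below are illustrative):

```python
import warnings


def deprecate_public_alias(private_name, removal_version):
    """Build a public method that warns once on first use, then forwards to
    the private implementation -- the shape of the wrappers in this commit."""
    warned = set()

    def wrapper(self, *args, **kwargs):
        if private_name not in warned:  # warn only on the first call
            warned.add(private_name)
            warnings.warn(
                f"Calling `{private_name.lstrip('_')}` directly is deprecated and "
                f"will be removed in {removal_version}. Use `generate` instead.",
                FutureWarning,
            )
        return getattr(self, private_name)(*args, **kwargs)

    return wrapper


class Mixin:
    def _greedy_search(self, x):
        return x + 1  # stand-in for the real decoding loop

    # Old public name kept as a warning shim.
    greedy_search = deprecate_public_alias("_greedy_search", "v4.41")


m = Mixin()
print(m.greedy_search(1))  # 2, with a FutureWarning on the first call
```

Keeping the shim for a release cycle lets downstream callers of `model.greedy_search(...)` keep working while signaling that only `generate` is public API.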
src/transformers/models/musicgen/modeling_musicgen.py

```diff
@@ -1336,7 +1336,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -1361,7 +1361,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
@@ -2402,7 +2402,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
@@ -2428,7 +2428,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
```
src/transformers/models/rag/modeling_rag.py

```diff
@@ -1539,7 +1539,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                     f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
                     " greedy search."
                 )
-            return self.greedy_search(
+            return self._greedy_search(
                 input_ids,
                 logits_processor=pre_processor,
                 max_length=generation_config.max_length,
@@ -1559,7 +1559,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
-            return self.beam_search(
+            return self._beam_search(
                 input_ids,
                 beam_scorer,
                 logits_processor=pre_processor,
```