Unverified commit 7e0ddf89 authored by Joao Gante, committed by GitHub

Generate: consolidate output classes (#28494)

parent 72db39c0
@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`], as we can
+The `generation_output` object is a [`~generation.GenerateDecoderOnlyOutput`], as we can
see in the documentation of that class below, it means it has the following attributes:
- `sequences`: the generated sequences of tokens
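Below is a minimal sketch (not part of this diff) of what the renamed return type looks like in practice. It repeats the snippet above in a self-contained form and assumes a decoder-only checkpoint such as `gpt2`, which is an illustrative choice rather than one taken from this page.

```python
# Minimal sketch: inspect the consolidated output object produced by
# generate(..., return_dict_in_generate=True, output_scores=True).
# The "gpt2" checkpoint is an illustrative assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerateDecoderOnlyOutput

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

assert isinstance(generation_output, GenerateDecoderOnlyOutput)
print(generation_output.sequences.shape)  # (batch_size, generated_sequence_length) token ids
print(len(generation_output.scores))      # one (batch_size, vocab_size) tensor per generated step
```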
@@ -77,25 +77,13 @@ We document here all output types.
### PyTorch
-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
......
@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`].
+The `generation_output` object is a [`~generation.GenerateDecoderOnlyOutput`].
See the documentation of that class below; this means it has the following attributes:
- `sequences`: the generated sequences of tokens
@@ -76,25 +76,13 @@ generation_output[:2]
### PyTorch
-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
......
@@ -43,7 +43,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is an instance of [`~generation.GreedySearchDecoderOnlyOutput`]; as we can see from that class's documentation, this means it has the following attributes:
+The `generation_output` object is an instance of [`~generation.GenerateDecoderOnlyOutput`]; as we can see from that class's documentation, this means it has the following attributes:
- `sequences`: the generated sequences of tokens
- `scores` (optional): the prediction scores of the language modeling head, for each generation step
@@ -70,25 +70,13 @@ generation_output[:2]
### PyTorch
-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
......
@@ -94,6 +94,10 @@ else:
"BeamSampleDecoderOnlyOutput",
"ContrastiveSearchEncoderDecoderOutput",
"ContrastiveSearchDecoderOnlyOutput",
+"GenerateBeamDecoderOnlyOutput",
+"GenerateBeamEncoderDecoderOutput",
+"GenerateDecoderOnlyOutput",
+"GenerateEncoderDecoderOutput",
]
try:
@@ -222,6 +226,10 @@ if TYPE_CHECKING:
BeamSearchEncoderDecoderOutput,
ContrastiveSearchDecoderOnlyOutput,
ContrastiveSearchEncoderDecoderOutput,
+GenerateBeamDecoderOnlyOutput,
+GenerateBeamEncoderDecoderOutput,
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
......
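The entries added above to the import structure and the `TYPE_CHECKING` block expose the consolidated classes, so they can be imported from `transformers.generation`, exactly as the updated tests further down do. A small sketch of the resulting import:

```python
# The consolidated output classes introduced by this change, importable from
# transformers.generation once the entries above are in place.
from transformers.generation import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)
```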
This diff is collapsed.
@@ -1197,18 +1197,14 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchDecoderOnlyOutput`],
-- [`~generation.SampleDecoderOnlyOutput`],
-- [`~generation.BeamSearchDecoderOnlyOutput`],
-- [`~generation.BeamSampleDecoderOnlyOutput`]
+- [`~generation.GenerateDecoderOnlyOutput`],
+- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
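To make the updated docstring concrete, here is a hedged sketch (not from the PR) of how a caller might branch on the consolidated classes after this change; the helper name is hypothetical and the attribute hint is only indicative.

```python
# Hypothetical helper: classify what `generate` returned after this change.
# Greedy/sample/contrastive decoding -> Generate*Output; beam-based decoding
# -> GenerateBeam*Output; without return_dict_in_generate=True, a plain tensor.
from transformers.generation import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)

def describe_generate_output(output) -> str:
    if isinstance(output, (GenerateBeamDecoderOnlyOutput, GenerateBeamEncoderDecoderOutput)):
        return "beam-based decoding output (e.g. carries `sequences_scores`)"
    if isinstance(output, (GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput)):
        return "greedy/sample/contrastive decoding output"
    return "plain tensor of generated token ids"
```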
@@ -2244,18 +2240,14 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchDecoderOnlyOutput`],
-- [`~generation.SampleDecoderOnlyOutput`],
-- [`~generation.BeamSearchDecoderOnlyOutput`],
-- [`~generation.BeamSampleDecoderOnlyOutput`]
+- [`~generation.GenerateDecoderOnlyOutput`],
+- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
......
@@ -1264,10 +1264,8 @@ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel):
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
if generation_config is None:
......
@@ -2845,11 +2845,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# prepare text_decoder_input_ids
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
@@ -3134,11 +3131,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
......
@@ -3110,11 +3110,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# prepare text_decoder_input_ids
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
@@ -3409,11 +3406,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
......
@@ -1968,10 +1968,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
......
@@ -65,6 +65,10 @@ if is_torch_available():
DisjunctiveConstraint,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
+GenerateBeamDecoderOnlyOutput,
+GenerateBeamEncoderDecoderOutput,
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
HammingDiversityLogitsProcessor,
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
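The retrocompatibility checks above can only pass if the legacy class names still match the objects now returned by the consolidated code path. The diff that defines these classes is collapsed on this page, but one way to satisfy both assertions is to keep the old names as aliases of the new classes; the snippet below illustrates that idea and is not the library's actual definitions.

```python
# Illustration only: if the legacy name is an alias of the consolidated class,
# isinstance checks written against either name keep passing.
class GenerateDecoderOnlyOutput:  # stand-in for the real ModelOutput subclass
    pass

GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput  # legacy alias

output = GenerateDecoderOnlyOutput()
assert isinstance(output, GenerateDecoderOnlyOutput)
assert isinstance(output, GreedySearchDecoderOnlyOutput)  # retrocompatibility check
```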
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
+self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
+self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
......
@@ -53,12 +53,10 @@ if is_torch_available():
set_seed,
)
from transformers.generation import (
-GreedySearchDecoderOnlyOutput,
-GreedySearchEncoderDecoderOutput,
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
-SampleDecoderOnlyOutput,
-SampleEncoderDecoderOutput,
)
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
# additional post-processing in the former
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
-self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
+self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
-self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
+self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _, _, max_length = self._get_input_ids_and_config()
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
......