"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7a7ee28cb9b930018e9ca49d18c0445fb14badd6"
Unverified commit 7e0ddf89, authored by Joao Gante and committed by GitHub

Generate: consolidate output classes (#28494)

parent 72db39c0
@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`], as we can
+The `generation_output` object is a [`~generation.GenerateDecoderOnlyOutput`], as we can
see in the documentation of that class below, it means it has the following attributes:
- `sequences`: the generated sequences of tokens
@@ -77,25 +77,13 @@ We document here all output types.
### PyTorch
-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
...
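The documentation hunk above describes the consolidated decoder-only output in use. A minimal end-to-end sketch, assuming the `gpt2` checkpoint (the surrounding doc example uses a GPT-2 tokenizer and model) and only the attributes named in the doc (`sequences`, `scores`):

```python
# Minimal sketch of the documented example, assuming the `gpt2` checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

print(type(generation_output).__name__)   # GenerateDecoderOnlyOutput (after this change)
print(generation_output.sequences.shape)  # (batch_size, sequence_length), prompt included
print(len(generation_output.scores))      # one score tensor per generated token
```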
(Japanese documentation)
@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`].
+The `generation_output` object is a [`~generation.GenerateDecoderOnlyOutput`].
See the documentation of that class below; it means the object has the following attributes:
- `sequences`: the generated sequences of tokens
@@ -76,25 +76,13 @@ generation_output[:2]
### PyTorch
-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
...
(Chinese documentation)
@@ -43,7 +43,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
```
-The `generation_output` object is an instance of [`~generation.GreedySearchDecoderOnlyOutput`]; as we can see from that class's documentation, this means it has the following attributes:
+The `generation_output` object is an instance of [`~generation.GenerateDecoderOnlyOutput`]; as we can see from that class's documentation, this means it has the following attributes:
- `sequences`: the generated token sequences
- `scores` (optional): the prediction scores of the language modeling head at each generation step
@@ -70,25 +70,13 @@ generation_output[:2]
### PyTorch
-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
...
@@ -94,6 +94,10 @@ else:
"BeamSampleDecoderOnlyOutput",
"ContrastiveSearchEncoderDecoderOutput",
"ContrastiveSearchDecoderOnlyOutput",
+"GenerateBeamDecoderOnlyOutput",
+"GenerateBeamEncoderDecoderOutput",
+"GenerateDecoderOnlyOutput",
+"GenerateEncoderDecoderOutput",
]
try:
@@ -222,6 +226,10 @@ if TYPE_CHECKING:
BeamSearchEncoderDecoderOutput,
ContrastiveSearchDecoderOnlyOutput,
ContrastiveSearchEncoderDecoderOutput,
+GenerateBeamDecoderOnlyOutput,
+GenerateBeamEncoderDecoderOutput,
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
GenerationMixin,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
...
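Both hunks above register the four consolidated classes in the package's import structure while leaving the legacy names in place, so either spelling remains importable. A minimal usage sketch, assuming a transformers version that includes this commit:

```python
# Both the consolidated output classes and the legacy, algorithm-specific
# names are exported, as the hunks above show.
from transformers.generation import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
    GreedySearchDecoderOnlyOutput,  # legacy name, still importable
)
```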
This diff is collapsed.
@@ -1197,18 +1197,14 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchDecoderOnlyOutput`],
-- [`~generation.SampleDecoderOnlyOutput`],
-- [`~generation.BeamSearchDecoderOnlyOutput`],
-- [`~generation.BeamSampleDecoderOnlyOutput`]
+- [`~generation.GenerateDecoderOnlyOutput`],
+- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
@@ -2244,18 +2240,14 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchDecoderOnlyOutput`],
-- [`~generation.SampleDecoderOnlyOutput`],
-- [`~generation.BeamSearchDecoderOnlyOutput`],
-- [`~generation.BeamSampleDecoderOnlyOutput`]
+- [`~generation.GenerateDecoderOnlyOutput`],
+- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if generation_config is None:
...
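Per the updated docstrings, `generate()` with `return_dict_in_generate=True` now yields one of only two output classes per architecture: the plain or the beam variant of the decoder-only / encoder-decoder output. A minimal sketch of selecting the expected classes (the helper is hypothetical, not a library API):

```python
from transformers.generation import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)


def expected_output_types(model):
    """Hypothetical helper: which output classes model.generate(...,
    return_dict_in_generate=True) may return, per the docstrings above."""
    if model.config.is_encoder_decoder:
        return (GenerateEncoderDecoderOutput, GenerateBeamEncoderDecoderOutput)
    return (GenerateDecoderOnlyOutput, GenerateBeamDecoderOnlyOutput)


# e.g. outputs = model.generate(**inputs, return_dict_in_generate=True)
#      assert isinstance(outputs, expected_output_types(model))
```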
@@ -1264,10 +1264,8 @@ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel):
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
if generation_config is None:
...
@@ -2845,11 +2845,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# prepare text_decoder_input_ids
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
@@ -3134,11 +3131,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
...
@@ -3110,11 +3110,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# prepare text_decoder_input_ids
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
@@ -3409,11 +3406,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
...
@@ -1968,10 +1968,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
-- [`~generation.GreedySearchEncoderDecoderOutput`],
-- [`~generation.SampleEncoderDecoderOutput`],
-- [`~generation.BeamSearchEncoderDecoderOutput`],
-- [`~generation.BeamSampleEncoderDecoderOutput`]
+- [`~generation.GenerateEncoderDecoderOutput`],
+- [`~generation.GenerateBeamEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
...
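The Whisper docstring above distinguishes short-form inputs, which can return one of the structured outputs, from long-form inputs, which return only the generated sequence ids, so downstream code may receive either. A hedged sketch of handling both return types (the helper name is hypothetical):

```python
import torch

from transformers.generation import (
    GenerateBeamEncoderDecoderOutput,
    GenerateEncoderDecoderOutput,
)


def extract_sequences(generate_result):
    """Hypothetical helper: pull token ids out of a generate() result that,
    per the docstring above, may be a structured output or a plain tensor."""
    if isinstance(generate_result, (GenerateEncoderDecoderOutput, GenerateBeamEncoderDecoderOutput)):
        return generate_result.sequences  # structured output keeps the ids in `.sequences`
    if isinstance(generate_result, torch.Tensor):
        return generate_result  # long-form path returns only the ids
    raise TypeError(f"Unexpected generate() return type: {type(generate_result)!r}")
```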
@@ -65,6 +65,10 @@ if is_torch_available():
DisjunctiveConstraint,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
+GenerateBeamDecoderOnlyOutput,
+GenerateBeamEncoderDecoderOutput,
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput,
GreedySearchEncoderDecoderOutput,
HammingDiversityLogitsProcessor,
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
else:
+self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else:
+self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
return_dict_in_generate=True,
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
)
if model.config.is_encoder_decoder:
+self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
else:
+self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+# Retrocompatibility check
self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
...
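The tests above first assert against the new classes and then, under a `# Retrocompatibility check` comment, against the legacy ones on the same objects; for both assertions to pass, the legacy names must resolve to the consolidated classes (or subclasses of them). The actual definitions live in the collapsed `generation/utils.py` diff; the sketch below only illustrates the aliasing pattern the tests imply.

```python
# Sketch only: one way the "Retrocompatibility check" assertions above can hold.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class GenerateDecoderOnlyOutput:
    """Stand-in for the consolidated output class (illustration only)."""
    sequences: Optional[torch.LongTensor] = None
    scores: Optional[Tuple[torch.FloatTensor, ...]] = None


# Binding the legacy name to the new class makes isinstance() checks against
# either name succeed on the same object.
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput

out = GenerateDecoderOnlyOutput(sequences=torch.tensor([[0, 1, 2]]))
assert isinstance(out, GenerateDecoderOnlyOutput)
assert isinstance(out, GreedySearchDecoderOnlyOutput)  # legacy name still matches
```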
@@ -53,12 +53,10 @@ if is_torch_available():
set_seed,
)
from transformers.generation import (
-GreedySearchDecoderOnlyOutput,
-GreedySearchEncoderDecoderOutput,
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
-SampleDecoderOnlyOutput,
-SampleEncoderDecoderOutput,
)
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
# additional post-processing in the former
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
-self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
+self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
def test_greedy_generate_stereo_outputs(self):
for model_class in self.greedy_sample_model_classes:
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_sample_generate(self):
for model_class in self.greedy_sample_model_classes:
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
-self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
+self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
def test_generate_without_input_ids(self):
config, _, _, _, max_length = self._get_input_ids_and_config()
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate=True,
)
-self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
self.assertNotIn(config.pad_token_id, output_generate)
...