Commit 7e0ddf89 (unverified), authored Jan 15, 2024 by Joao Gante, committed via GitHub on Jan 15, 2024
Parent: 72db39c0

Generate: consolidate output classes (#28494)
Showing 12 changed files with 176 additions and 458 deletions (+176, -458).
docs/source/en/internal/generation_utils.md  +5  -17
docs/source/ja/internal/generation_utils.md  +5  -17
docs/source/zh/internal/generation_utils.md  +5  -17
src/transformers/generation/__init__.py  +8  -0
src/transformers/generation/utils.py  +75  -343
src/transformers/models/musicgen/modeling_musicgen.py  +8  -16
src/transformers/models/pop2piano/modeling_pop2piano.py  +2  -4
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py  +4  -10
src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py  +4  -10
src/transformers/models/whisper/modeling_whisper.py  +2  -4
tests/generation/test_utils.py  +40  -0
tests/models/musicgen/test_modeling_musicgen.py  +18  -20
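Taken together, the documentation and docstring changes below replace the per-decoding-strategy output classes with four consolidated ones. The mapping below is a summary sketch inferred from those diffs; the contrastive-search entries are an assumption based on the removed autodoc references, and the actual class definitions live in the collapsed `src/transformers/generation/utils.py` diff.

```python
# Legacy per-strategy output class -> consolidated output class (inferred from the diffs below)
LEGACY_TO_CONSOLIDATED = {
    "GreedySearchDecoderOnlyOutput": "GenerateDecoderOnlyOutput",
    "SampleDecoderOnlyOutput": "GenerateDecoderOnlyOutput",
    "ContrastiveSearchDecoderOnlyOutput": "GenerateDecoderOnlyOutput",  # assumption
    "GreedySearchEncoderDecoderOutput": "GenerateEncoderDecoderOutput",
    "SampleEncoderDecoderOutput": "GenerateEncoderDecoderOutput",
    "ContrastiveSearchEncoderDecoderOutput": "GenerateEncoderDecoderOutput",  # assumption
    "BeamSearchDecoderOnlyOutput": "GenerateBeamDecoderOnlyOutput",
    "BeamSampleDecoderOnlyOutput": "GenerateBeamDecoderOnlyOutput",
    "BeamSearchEncoderDecoderOutput": "GenerateBeamEncoderDecoderOutput",
    "BeamSampleEncoderDecoderOutput": "GenerateBeamEncoderDecoderOutput",
}
```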
docs/source/en/internal/generation_utils.md

@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
 generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
 ```

-The `generation_output` object is a [`~generation.GreedySearchDecoderOnlyOutput`], as we can
+The `generation_output` object is a [`~generation.GenerateDecoderOnlyOutput`], as we can
 see in the documentation of that class below, it means it has the following attributes:

 - `sequences`: the generated sequences of tokens
...
@@ -77,25 +77,13 @@ We document here all output types.

 ### PyTorch

-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput

 ### TensorFlow
...
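The English documentation change above is easiest to read next to a concrete call. A minimal sketch, assuming a small decoder-only checkpoint such as `gpt2` (the checkpoint choice is illustrative, not part of this commit):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerateDecoderOnlyOutput

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)

# As the updated docs state, the returned object is a GenerateDecoderOnlyOutput
# with (at least) `sequences` and, because output_scores=True, `scores`.
assert isinstance(generation_output, GenerateDecoderOnlyOutput)
print(generation_output.sequences.shape)  # (batch_size, prompt length + generated tokens)
print(len(generation_output.scores))      # one score tensor per generated token
```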
docs/source/ja/internal/generation_utils.md

@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
 generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
 ```

-`generation_output` オブジェクトは、できる限り [`~generation.GreedySearchDecoderOnlyOutput`] です。
+`generation_output` オブジェクトは、できる限り [`~generation.GenerateDecoderOnlyOutput`] です。
 以下のそのクラスのドキュメントを参照してください。これは、次の属性があることを意味します。

 - `sequences`: 生成されたトークンのシーケンス
...
@@ -76,25 +76,13 @@ generation_output[:2]

 ### PyTorch

-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput

 ### TensorFlow
...
docs/source/zh/internal/generation_utils.md

@@ -43,7 +43,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
 generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)
 ```

-`generation_output` 的对象是 [`~generation.GreedySearchDecoderOnlyOutput`] 的一个实例,从该类的文档中我们可以看到,这意味着它具有以下属性:
+`generation_output` 的对象是 [`~generation.GenerateDecoderOnlyOutput`] 的一个实例,从该类的文档中我们可以看到,这意味着它具有以下属性:

 - `sequences`: 生成的tokens序列
 - `scores` (可选): 每个生成步骤的语言建模头的预测分数
...
@@ -70,25 +70,13 @@ generation_output[:2]

 ### PyTorch

-[[autodoc]] generation.GreedySearchEncoderDecoderOutput
-[[autodoc]] generation.GreedySearchDecoderOnlyOutput
-[[autodoc]] generation.SampleEncoderDecoderOutput
-[[autodoc]] generation.SampleDecoderOnlyOutput
-[[autodoc]] generation.BeamSearchEncoderDecoderOutput
-[[autodoc]] generation.BeamSearchDecoderOnlyOutput
-[[autodoc]] generation.BeamSampleEncoderDecoderOutput
-[[autodoc]] generation.BeamSampleDecoderOnlyOutput
-[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
-[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
+[[autodoc]] generation.GenerateDecoderOnlyOutput
+[[autodoc]] generation.GenerateEncoderDecoderOutput
+[[autodoc]] generation.GenerateBeamDecoderOnlyOutput
+[[autodoc]] generation.GenerateBeamEncoderDecoderOutput

 ### TensorFlow
...
src/transformers/generation/__init__.py

@@ -94,6 +94,10 @@ else:
         "BeamSampleDecoderOnlyOutput",
         "ContrastiveSearchEncoderDecoderOutput",
         "ContrastiveSearchDecoderOnlyOutput",
+        "GenerateBeamDecoderOnlyOutput",
+        "GenerateBeamEncoderDecoderOutput",
+        "GenerateDecoderOnlyOutput",
+        "GenerateEncoderDecoderOutput",
     ]

 try:
...
@@ -222,6 +226,10 @@ if TYPE_CHECKING:
         BeamSearchEncoderDecoderOutput,
         ContrastiveSearchDecoderOnlyOutput,
         ContrastiveSearchEncoderDecoderOutput,
+        GenerateBeamDecoderOnlyOutput,
+        GenerateBeamEncoderDecoderOutput,
+        GenerateDecoderOnlyOutput,
+        GenerateEncoderDecoderOutput,
         GenerationMixin,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
...
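With the `__init__.py` additions above, the consolidated classes are exported through the module's lazy-import structure alongside the legacy names; a quick sketch of what becomes importable:

```python
# New consolidated classes, exported by the lazy-import structure above
from transformers.generation import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)

# The legacy names visible in the surrounding context remain exported,
# so existing imports keep working.
from transformers.generation import GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput
```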
src/transformers/generation/utils.py

This diff is collapsed.
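The collapsed `utils.py` diff (+75/-343) is where the consolidation itself happens and is not reproduced on this page. One pattern consistent with the retrocompatibility checks added in `tests/generation/test_utils.py` below is to keep the legacy names as aliases of the new classes, so that `isinstance` checks against either name succeed. A hedged sketch of that idea, with an illustrative field set, not the verbatim implementation:

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.utils import ModelOutput


@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
    """Consolidated output for decoder-only generation (illustrative field set)."""

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


# Keeping the old names as aliases of the new class would preserve
# isinstance() retrocompatibility, as the updated tests require.
GreedySearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
SampleDecoderOnlyOutput = GenerateDecoderOnlyOutput
ContrastiveSearchDecoderOnlyOutput = GenerateDecoderOnlyOutput
```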
src/transformers/models/musicgen/modeling_musicgen.py

@@ -1197,18 +1197,14 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
             If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchDecoderOnlyOutput`],
-                - [`~generation.SampleDecoderOnlyOutput`],
-                - [`~generation.BeamSearchDecoderOnlyOutput`],
-                - [`~generation.BeamSampleDecoderOnlyOutput`]
+                - [`~generation.GenerateDecoderOnlyOutput`],
+                - [`~generation.GenerateBeamDecoderOnlyOutput`]
             If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
         if generation_config is None:
...
@@ -2244,18 +2240,14 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
             If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchDecoderOnlyOutput`],
-                - [`~generation.SampleDecoderOnlyOutput`],
-                - [`~generation.BeamSearchDecoderOnlyOutput`],
-                - [`~generation.BeamSampleDecoderOnlyOutput`]
+                - [`~generation.GenerateDecoderOnlyOutput`],
+                - [`~generation.GenerateBeamDecoderOnlyOutput`]
             If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         # 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
         if generation_config is None:
...
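The updated docstrings above state which consolidated class to expect depending on `model.config.is_encoder_decoder` and on whether beam search is used. A small sketch of how calling code might rely on that; the helper name and signature are illustrative, not part of the library:

```python
from transformers.generation import (
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateEncoderDecoderOutput,
)


def expected_output_type(model, num_beams: int = 1):
    """Return the consolidated output class `generate()` should produce for `model`.

    Mirrors the updated docstrings: the *EncoderDecoder* variants for encoder-decoder
    models, the *DecoderOnly* variants otherwise, and the *Beam* variants when beam
    search (num_beams > 1) is used.
    """
    if model.config.is_encoder_decoder:
        return GenerateBeamEncoderDecoderOutput if num_beams > 1 else GenerateEncoderDecoderOutput
    return GenerateBeamDecoderOnlyOutput if num_beams > 1 else GenerateDecoderOnlyOutput
```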
src/transformers/models/pop2piano/modeling_pop2piano.py

@@ -1264,10 +1264,8 @@ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel):
             or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
             Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         if generation_config is None:
...
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py

@@ -2845,11 +2845,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
             or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         # prepare text_decoder_input_ids
         text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
...
@@ -3134,11 +3131,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
             or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
         # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
...
src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py

@@ -3110,11 +3110,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel):
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
             or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         # prepare text_decoder_input_ids
         text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
...
@@ -3409,11 +3406,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
             [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
             or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
             [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
         text_decoder_input_ids = kwargs.pop("decoder_input_ids", None)
         # overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
...
src/transformers/models/whisper/modeling_whisper.py

@@ -1968,10 +1968,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
-                - [`~generation.GreedySearchEncoderDecoderOutput`],
-                - [`~generation.SampleEncoderDecoderOutput`],
-                - [`~generation.BeamSearchEncoderDecoderOutput`],
-                - [`~generation.BeamSampleEncoderDecoderOutput`]
+                - [`~generation.GenerateEncoderDecoderOutput`],
+                - [`~generation.GenerateBeamEncoderDecoderOutput`]
             else only the generated output sequence ids are returned.
...
tests/generation/test_utils.py

@@ -65,6 +65,10 @@ if is_torch_available():
         DisjunctiveConstraint,
         ForcedBOSTokenLogitsProcessor,
         ForcedEOSTokenLogitsProcessor,
+        GenerateBeamDecoderOnlyOutput,
+        GenerateBeamEncoderDecoderOutput,
+        GenerateDecoderOnlyOutput,
+        GenerateEncoderDecoderOutput,
         GreedySearchDecoderOnlyOutput,
         GreedySearchEncoderDecoderOutput,
         HammingDiversityLogitsProcessor,
...
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+                self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
             else:
+                self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+                self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
...
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
+                self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
             else:
+                self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
+                self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
...
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
             )

             if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
             else:
+                self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
...
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_beam_sample, GenerateBeamEncoderDecoderOutput)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
             else:
+                self.assertIsInstance(output_beam_sample, GenerateBeamDecoderOnlyOutput)
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
...
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
                 return_dict_in_generate=True,
             )

             if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_group_beam_search, GenerateBeamEncoderDecoderOutput)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_group_beam_search, BeamSearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
             else:
+                self.assertIsInstance(output_group_beam_search, GenerateBeamDecoderOnlyOutput)
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_group_beam_search, BeamSearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
...
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
             )

             if model.config.is_encoder_decoder:
+                self.assertIsInstance(output_beam_search, GenerateBeamEncoderDecoderOutput)
+                self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_beam_search, BeamSearchEncoderDecoderOutput)
                 self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
             else:
+                self.assertIsInstance(output_beam_search, GenerateBeamDecoderOnlyOutput)
+                self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
+                # Retrocompatibility check
                 self.assertIsInstance(output_beam_search, BeamSearchDecoderOnlyOutput)
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
...
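The test changes above assert the new consolidated type first and then re-assert the legacy type as a retrocompatibility check. A stripped-down, self-contained sketch of that pattern (a standalone unittest with an illustrative `gpt2` checkpoint, not part of the repository's test suite):

```python
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerateDecoderOnlyOutput, GreedySearchDecoderOnlyOutput


class OutputClassRetrocompatTest(unittest.TestCase):
    def test_greedy_output_type(self):
        tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        inputs = tokenizer("Hello", return_tensors="pt")

        output = model.generate(
            **inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True
        )

        # New consolidated class...
        self.assertIsInstance(output, GenerateDecoderOnlyOutput)
        # ...and the legacy name still matches (retrocompatibility check).
        self.assertIsInstance(output, GreedySearchDecoderOnlyOutput)


if __name__ == "__main__":
    unittest.main()
```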
tests/models/musicgen/test_modeling_musicgen.py

@@ -53,12 +53,10 @@ if is_torch_available():
         set_seed,
     )
     from transformers.generation import (
-        GreedySearchDecoderOnlyOutput,
-        GreedySearchEncoderDecoderOutput,
+        GenerateDecoderOnlyOutput,
+        GenerateEncoderDecoderOutput,
         InfNanRemoveLogitsProcessor,
         LogitsProcessorList,
-        SampleDecoderOnlyOutput,
-        SampleEncoderDecoderOutput,
     )
...
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)

             self.assertNotIn(config.pad_token_id, output_generate)
...
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)

     # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
     # additional post-processing in the former
...
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
+            self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
+            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)

     def test_greedy_generate_stereo_outputs(self):
         for model_class in self.greedy_sample_model_classes:
...
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
+            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
+            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)

             self.assertNotIn(config.pad_token_id, output_generate)
...
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-            self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+            self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)

             self.assertNotIn(config.pad_token_id, output_generate)
...
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-            self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+            self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)

     def test_sample_generate(self):
         for model_class in self.greedy_sample_model_classes:
...
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_sample, SampleEncoderDecoderOutput)
-            self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
+            self.assertIsInstance(output_sample, GenerateEncoderDecoderOutput)
+            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)

     def test_generate_without_input_ids(self):
         config, _, _, _, max_length = self._get_input_ids_and_config()
...
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
                 return_dict_in_generate=True,
             )

-            self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
-            self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
+            self.assertIsInstance(output_greedy, GenerateEncoderDecoderOutput)
+            self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)

             self.assertNotIn(config.pad_token_id, output_generate)
...