Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7e0ddf89
Unverified
Commit
7e0ddf89
authored
Jan 15, 2024
by
Joao Gante
Committed by
GitHub
Jan 15, 2024
Browse files
Generate: consolidate output classes (#28494)
parent
72db39c0
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
176 additions
and
458 deletions
+176
-458
docs/source/en/internal/generation_utils.md
docs/source/en/internal/generation_utils.md
+5
-17
docs/source/ja/internal/generation_utils.md
docs/source/ja/internal/generation_utils.md
+5
-17
docs/source/zh/internal/generation_utils.md
docs/source/zh/internal/generation_utils.md
+5
-17
src/transformers/generation/__init__.py
src/transformers/generation/__init__.py
+8
-0
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+75
-343
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+8
-16
src/transformers/models/pop2piano/modeling_pop2piano.py
src/transformers/models/pop2piano/modeling_pop2piano.py
+2
-4
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
...transformers/models/seamless_m4t/modeling_seamless_m4t.py
+4
-10
src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
...ormers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
+4
-10
src/transformers/models/whisper/modeling_whisper.py
src/transformers/models/whisper/modeling_whisper.py
+2
-4
tests/generation/test_utils.py
tests/generation/test_utils.py
+40
-0
tests/models/musicgen/test_modeling_musicgen.py
tests/models/musicgen/test_modeling_musicgen.py
+18
-20
No files found.
docs/source/en/internal/generation_utils.md
View file @
7e0ddf89
...
...
@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output
=
model
.
generate
(
**
inputs
,
return_dict_in_generate
=
True
,
output_scores
=
True
)
```
The
`generation_output`
object is a [
`~generation.G
reedySearch
DecoderOnlyOutput`
], as we can
The
`generation_output`
object is a [
`~generation.G
enerate
DecoderOnlyOutput`
], as we can
see in the documentation of that class below, it means it has the following attributes:
-
`sequences`
: the generated sequences of tokens
...
...
@@ -77,25 +77,13 @@ We document here all output types.
### PyTorch
[[autodoc]] generation.G
reedySearchEncoder
DecoderOutput
[[autodoc]] generation.G
enerate
DecoderO
nlyO
utput
[[autodoc]] generation.G
reedySearch
DecoderO
nlyO
utput
[[autodoc]] generation.G
enerateEncoder
DecoderOutput
[[autodoc]] generation.
SampleEncoder
DecoderOutput
[[autodoc]] generation.
GenerateBeam
DecoderO
nlyO
utput
[[autodoc]] generation.SampleDecoderOnlyOutput
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
...
...
docs/source/ja/internal/generation_utils.md
View file @
7e0ddf89
...
...
@@ -45,7 +45,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output
=
model
.
generate
(
**
inputs
,
return_dict_in_generate
=
True
,
output_scores
=
True
)
```
`generation_output`
オブジェクトは、できる限り [
`~generation.G
reedySearch
DecoderOnlyOutput`
] です。
`generation_output`
オブジェクトは、できる限り [
`~generation.G
enerate
DecoderOnlyOutput`
] です。
以下のそのクラスのドキュメントを参照してください。これは、次の属性があることを意味します。
-
`sequences`
: 生成されたトークンのシーケンス
...
...
@@ -76,25 +76,13 @@ generation_output[:2]
### PyTorch
[[autodoc]] generation.G
reedySearchEncoder
DecoderOutput
[[autodoc]] generation.G
enerate
DecoderO
nlyO
utput
[[autodoc]] generation.G
reedySearch
DecoderO
nlyO
utput
[[autodoc]] generation.G
enerateEncoder
DecoderOutput
[[autodoc]] generation.
SampleEncoder
DecoderOutput
[[autodoc]] generation.
GenerateBeam
DecoderO
nlyO
utput
[[autodoc]] generation.SampleDecoderOnlyOutput
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
...
...
docs/source/zh/internal/generation_utils.md
View file @
7e0ddf89
...
...
@@ -43,7 +43,7 @@ inputs = tokenizer("Hello, my dog is cute and ", return_tensors="pt")
generation_output
=
model
.
generate
(
**
inputs
,
return_dict_in_generate
=
True
,
output_scores
=
True
)
```
`generation_output`
的对象是 [
`~generation.G
reedySearch
DecoderOnlyOutput`
] 的一个实例,从该类的文档中我们可以看到,这意味着它具有以下属性:
`generation_output`
的对象是 [
`~generation.G
enerate
DecoderOnlyOutput`
] 的一个实例,从该类的文档中我们可以看到,这意味着它具有以下属性:
-
`sequences`
: 生成的tokens序列
-
`scores`
(可选): 每个生成步骤的语言建模头的预测分数
...
...
@@ -70,25 +70,13 @@ generation_output[:2]
### PyTorch
[[autodoc]] generation.G
reedySearchEncoder
DecoderOutput
[[autodoc]] generation.G
enerate
DecoderO
nlyO
utput
[[autodoc]] generation.G
reedySearch
DecoderO
nlyO
utput
[[autodoc]] generation.G
enerateEncoder
DecoderOutput
[[autodoc]] generation.
SampleEncoder
DecoderOutput
[[autodoc]] generation.
GenerateBeam
DecoderO
nlyO
utput
[[autodoc]] generation.SampleDecoderOnlyOutput
[[autodoc]] generation.BeamSearchEncoderDecoderOutput
[[autodoc]] generation.BeamSearchDecoderOnlyOutput
[[autodoc]] generation.BeamSampleEncoderDecoderOutput
[[autodoc]] generation.BeamSampleDecoderOnlyOutput
[[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
[[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
[[autodoc]] generation.GenerateBeamEncoderDecoderOutput
### TensorFlow
...
...
src/transformers/generation/__init__.py
View file @
7e0ddf89
...
...
@@ -94,6 +94,10 @@ else:
"BeamSampleDecoderOnlyOutput"
,
"ContrastiveSearchEncoderDecoderOutput"
,
"ContrastiveSearchDecoderOnlyOutput"
,
"GenerateBeamDecoderOnlyOutput"
,
"GenerateBeamEncoderDecoderOutput"
,
"GenerateDecoderOnlyOutput"
,
"GenerateEncoderDecoderOutput"
,
]
try
:
...
...
@@ -222,6 +226,10 @@ if TYPE_CHECKING:
BeamSearchEncoderDecoderOutput
,
ContrastiveSearchDecoderOnlyOutput
,
ContrastiveSearchEncoderDecoderOutput
,
GenerateBeamDecoderOnlyOutput
,
GenerateBeamEncoderDecoderOutput
,
GenerateDecoderOnlyOutput
,
GenerateEncoderDecoderOutput
,
GenerationMixin
,
GreedySearchDecoderOnlyOutput
,
GreedySearchEncoderDecoderOutput
,
...
...
src/transformers/generation/utils.py
View file @
7e0ddf89
This diff is collapsed.
Click to expand it.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
7e0ddf89
...
...
@@ -1197,18 +1197,14 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
- [`~generation.GenerateDecoderOnlyOutput`],
- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if
generation_config
is
None
:
...
...
@@ -2244,18 +2240,14 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
- [`~generation.GenerateDecoderOnlyOutput`],
- [`~generation.GenerateBeamDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the resulting objects
if
generation_config
is
None
:
...
...
src/transformers/models/pop2piano/modeling_pop2piano.py
View file @
7e0ddf89
...
...
@@ -1264,10 +1264,8 @@ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel):
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
if
generation_config
is
None
:
...
...
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
View file @
7e0ddf89
...
...
@@ -2845,11 +2845,8 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# prepare text_decoder_input_ids
text_decoder_input_ids
=
kwargs
.
pop
(
"decoder_input_ids"
,
None
)
...
...
@@ -3134,11 +3131,8 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
text_decoder_input_ids
=
kwargs
.
pop
(
"decoder_input_ids"
,
None
)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
...
...
src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
View file @
7e0ddf89
...
...
@@ -3110,11 +3110,8 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# prepare text_decoder_input_ids
text_decoder_input_ids
=
kwargs
.
pop
(
"decoder_input_ids"
,
None
)
...
...
@@ -3409,11 +3406,8 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel):
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`. The possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
text_decoder_input_ids
=
kwargs
.
pop
(
"decoder_input_ids"
,
None
)
# overwrite text_decoder_input_ids if tgt_lang is passed. The latter gets priority over decoder_input_ids.
...
...
src/transformers/models/whisper/modeling_whisper.py
View file @
7e0ddf89
...
...
@@ -1968,10 +1968,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
...
...
tests/generation/test_utils.py
View file @
7e0ddf89
...
...
@@ -65,6 +65,10 @@ if is_torch_available():
DisjunctiveConstraint
,
ForcedBOSTokenLogitsProcessor
,
ForcedEOSTokenLogitsProcessor
,
GenerateBeamDecoderOnlyOutput
,
GenerateBeamEncoderDecoderOutput
,
GenerateDecoderOnlyOutput
,
GenerateEncoderDecoderOutput
,
GreedySearchDecoderOnlyOutput
,
GreedySearchEncoderDecoderOutput
,
HammingDiversityLogitsProcessor
,
...
...
@@ -730,9 +734,15 @@ class GenerationTesterMixin:
)
if
model
.
config
.
is_encoder_decoder
:
self
.
assertIsInstance
(
output_greedy
,
GenerateEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateEncoderDecoderOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_greedy
,
GreedySearchEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GreedySearchEncoderDecoderOutput
)
else
:
self
.
assertIsInstance
(
output_greedy
,
GenerateDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateDecoderOnlyOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_greedy
,
GreedySearchDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GreedySearchDecoderOnlyOutput
)
...
...
@@ -848,9 +858,15 @@ class GenerationTesterMixin:
)
if
model
.
config
.
is_encoder_decoder
:
self
.
assertIsInstance
(
output_sample
,
GenerateEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateEncoderDecoderOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_sample
,
SampleEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
SampleEncoderDecoderOutput
)
else
:
self
.
assertIsInstance
(
output_sample
,
GenerateDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateDecoderOnlyOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_sample
,
SampleDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
SampleDecoderOnlyOutput
)
...
...
@@ -952,9 +968,15 @@ class GenerationTesterMixin:
return_dict_in_generate
=
True
,
)
if
model
.
config
.
is_encoder_decoder
:
self
.
assertIsInstance
(
output_beam_search
,
GenerateBeamEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamEncoderDecoderOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_beam_search
,
BeamSearchEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSearchEncoderDecoderOutput
)
else
:
self
.
assertIsInstance
(
output_beam_search
,
GenerateBeamDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamDecoderOnlyOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_beam_search
,
BeamSearchDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSearchDecoderOnlyOutput
)
...
...
@@ -1109,9 +1131,15 @@ class GenerationTesterMixin:
)
if
model
.
config
.
is_encoder_decoder
:
self
.
assertIsInstance
(
output_beam_sample
,
GenerateBeamEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamEncoderDecoderOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_beam_sample
,
BeamSampleEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSampleEncoderDecoderOutput
)
else
:
self
.
assertIsInstance
(
output_beam_sample
,
GenerateBeamDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamDecoderOnlyOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_beam_sample
,
BeamSampleDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSampleDecoderOnlyOutput
)
...
...
@@ -1238,9 +1266,15 @@ class GenerationTesterMixin:
return_dict_in_generate
=
True
,
)
if
model
.
config
.
is_encoder_decoder
:
self
.
assertIsInstance
(
output_group_beam_search
,
GenerateBeamEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamEncoderDecoderOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_group_beam_search
,
BeamSearchEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSearchEncoderDecoderOutput
)
else
:
self
.
assertIsInstance
(
output_group_beam_search
,
GenerateBeamDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamDecoderOnlyOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_group_beam_search
,
BeamSearchDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSearchDecoderOnlyOutput
)
...
...
@@ -1390,9 +1424,15 @@ class GenerationTesterMixin:
)
if
model
.
config
.
is_encoder_decoder
:
self
.
assertIsInstance
(
output_beam_search
,
GenerateBeamEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamEncoderDecoderOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_beam_search
,
BeamSearchEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSearchEncoderDecoderOutput
)
else
:
self
.
assertIsInstance
(
output_beam_search
,
GenerateBeamDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
GenerateBeamDecoderOnlyOutput
)
# Retrocompatibility check
self
.
assertIsInstance
(
output_beam_search
,
BeamSearchDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
BeamSearchDecoderOnlyOutput
)
...
...
tests/models/musicgen/test_modeling_musicgen.py
View file @
7e0ddf89
...
...
@@ -53,12 +53,10 @@ if is_torch_available():
set_seed
,
)
from
transformers.generation
import
(
G
reedySearch
DecoderOnlyOutput
,
G
reedySearch
EncoderDecoderOutput
,
G
enerate
DecoderOnlyOutput
,
G
enerate
EncoderDecoderOutput
,
InfNanRemoveLogitsProcessor
,
LogitsProcessorList
,
SampleDecoderOnlyOutput
,
SampleEncoderDecoderOutput
,
)
...
...
@@ -282,8 +280,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_greedy
,
G
reedySearch
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
G
reedySearch
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_greedy
,
G
enerate
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
G
enerate
DecoderOnlyOutput
)
self
.
assertNotIn
(
config
.
pad_token_id
,
output_generate
)
...
...
@@ -308,8 +306,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_greedy
,
G
reedySearch
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
G
reedySearch
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_greedy
,
G
enerate
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
G
enerate
DecoderOnlyOutput
)
# override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
# additional post-processing in the former
...
...
@@ -376,8 +374,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_sample
,
Sampl
eDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
Sampl
eDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_sample
,
Generat
eDecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
Generat
eDecoderOnlyOutput
)
def
test_greedy_generate_stereo_outputs
(
self
):
for
model_class
in
self
.
greedy_sample_model_classes
:
...
...
@@ -395,8 +393,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_greedy
,
G
reedySearch
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
G
reedySearch
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_greedy
,
G
enerate
DecoderOnlyOutput
)
self
.
assertIsInstance
(
output_generate
,
G
enerate
DecoderOnlyOutput
)
self
.
assertNotIn
(
config
.
pad_token_id
,
output_generate
)
...
...
@@ -1001,8 +999,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_greedy
,
G
reedySearch
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
G
reedySearch
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_greedy
,
G
enerate
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
G
enerate
EncoderDecoderOutput
)
self
.
assertNotIn
(
config
.
pad_token_id
,
output_generate
)
...
...
@@ -1026,8 +1024,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_greedy
,
G
reedySearch
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
G
reedySearch
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_greedy
,
G
enerate
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
G
enerate
EncoderDecoderOutput
)
def
test_sample_generate
(
self
):
for
model_class
in
self
.
greedy_sample_model_classes
:
...
...
@@ -1092,8 +1090,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_sample
,
Sampl
eEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
Sampl
eEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_sample
,
Generat
eEncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
Generat
eEncoderDecoderOutput
)
def
test_generate_without_input_ids
(
self
):
config
,
_
,
_
,
_
,
max_length
=
self
.
_get_input_ids_and_config
()
...
...
@@ -1141,8 +1139,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
return_dict_in_generate
=
True
,
)
self
.
assertIsInstance
(
output_greedy
,
G
reedySearch
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
G
reedySearch
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_greedy
,
G
enerate
EncoderDecoderOutput
)
self
.
assertIsInstance
(
output_generate
,
G
enerate
EncoderDecoderOutput
)
self
.
assertNotIn
(
config
.
pad_token_id
,
output_generate
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment