chenpangpang/transformers · commit 1a5c500f (unverified), parent 66ce9593
Authored Mar 20, 2024 by Joao Gante; committed via GitHub on Mar 20, 2024

Tests: Musicgen tests + `make fix-copies` (#29734)

* make fix-copies
* some tests fixed
* tests fixed
Showing 3 changed files, with 27 additions and 314 deletions (+27 −314):

src/transformers/models/musicgen_melody/modeling_musicgen_melody.py   +2   −2
tests/models/musicgen/test_modeling_musicgen.py                       +0   −99
tests/models/musicgen_melody/test_modeling_musicgen_melody.py         +25  −213
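Context for the modeling change below: it is a mechanical rename tracking transformers v4.39, where the low-level decoding helpers `greedy_search` and `sample` became private (`_greedy_search`, `_sample`); user code goes through `generate()`. A hedged sketch of that unchanged public entry point (the checkpoint name and options are illustrative assumptions, not part of this commit):

# Hedged sketch of the public API the renamed helpers sit behind;
# "facebook/musicgen-melody" is an assumed checkpoint, not referenced here.
from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")

inputs = processor(text=["happy rock"], padding=True, return_tensors="pt")
# do_sample=False routes generate() through the (now private) _greedy_search;
# do_sample=True routes it through _sample.
audio_values = model.generate(**inputs, do_sample=False, max_new_tokens=16)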
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -1294,7 +1294,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             )
         # 11. run greedy search
-        outputs = self.greedy_search(
+        outputs = self._greedy_search(
             input_ids,
             logits_processor=logits_processor,
             stopping_criteria=stopping_criteria,
@@ -1319,7 +1319,7 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
             )
         # 12. run sample
-        outputs = self.sample(
+        outputs = self._sample(
             input_ids,
             logits_processor=logits_processor,
             logits_warper=logits_warper,
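Both hunks are one-line renames at call sites inside `MusicgenMelodyForCausalLM.generate`. A minimal, self-contained sketch of the dispatch pattern they touch (an illustrative stand-in, not the transformers source):

# Illustrative stand-in for the dispatch generate() performs; the commit only
# changes which (private) helper name is called at each branch.
class TinyGenerator:
    def _greedy_search(self, tokens):
        return tokens + ["<argmax>"]   # pick the most likely next token

    def _sample(self, tokens):
        return tokens + ["<sampled>"]  # draw the next token from the distribution

    def generate(self, tokens, do_sample=False):
        return self._sample(tokens) if do_sample else self._greedy_search(tokens)

print(TinyGenerator().generate(["<bos>"]))  # ['<bos>', '<argmax>']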
tests/models/musicgen/test_modeling_musicgen.py
@@ -257,105 +257,6 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             warper_kwargs = {}
         return process_kwargs, warper_kwargs
 
-    # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
-    # additional post-processing in the former
-    def test_greedy_generate_dict_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
-
-    # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
-    # additional post-processing in the former
-    def test_greedy_generate_dict_outputs_use_cache(self):
-        for model_class in self.greedy_sample_model_classes:
-            # enable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            config.use_cache = True
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-    # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
-    # additional post-processing in the former
-    def test_sample_generate(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                max_length=max_length,
-            )
-
-            # check `generate()` and `sample()` are equal
-            output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                num_return_sequences=3,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
-            )
-            self.assertIsInstance(output_generate, torch.Tensor)
-
-    # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
-    # additional post-processing in the former
-    def test_sample_generate_dict_output(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
-                input_ids.shape[-1],
-                max_length=max_length,
-            )
-
-            output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                num_return_sequences=1,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
     def test_greedy_generate_stereo_outputs(self):
         for model_class in self.greedy_sample_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
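The four overridden tests above could be deleted because the parent `GenerationTesterMixin` helpers now return a single `output_generate`, so the base implementations cover Musicgen. What the file keeps is the `_get_logits_processor_and_warper_kwargs` override, whose tail appears as context at the top of the hunk. A hedged, standalone sketch of that contract (module-level for easy running; values illustrative):

# Standalone sketch of the kwargs contract the test mixin expects: two plain
# dicts that get forwarded to generate(), instead of prebuilt processor objects.
def get_logits_processor_and_warper_kwargs(input_length, max_length=None):
    process_kwargs = {
        "min_length": input_length + 1 if max_length is None else max_length - 1,
    }
    warper_kwargs = {}  # the Musicgen tests disable warpers to keep sampling stable
    return process_kwargs, warper_kwargs

process_kwargs, warper_kwargs = get_logits_processor_and_warper_kwargs(8, max_length=11)
print(process_kwargs, warper_kwargs)  # {'min_length': 10} {}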
tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -55,8 +55,6 @@ if is_torch_available():
     )
     from transformers.generation import (
         GenerateDecoderOnlyOutput,
-        InfNanRemoveLogitsProcessor,
-        LogitsProcessorList,
     )
 
 if is_torchaudio_available():
@@ -248,142 +246,24 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
         return config, input_ids, attention_mask, max_length
 
     @staticmethod
-    def _get_logits_processor_and_kwargs(
+    def _get_logits_processor_and_warper_kwargs(
         input_length,
-        eos_token_id,
         forced_bos_token_id=None,
         forced_eos_token_id=None,
         max_length=None,
-        diversity_penalty=None,
     ):
         process_kwargs = {
             "min_length": input_length + 1 if max_length is None else max_length - 1,
         }
-        logits_processor = LogitsProcessorList()
-        return process_kwargs, logits_processor
-
-    # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
-    # additional post-processing in the former
-    def test_greedy_generate_dict_outputs(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-            self.assertNotIn(config.pad_token_id, output_generate)
-
-    # override since we don't expect the outputs of `.generate` and `.greedy_search` to be the same, since we perform
-    # additional post-processing in the former
-    def test_greedy_generate_dict_outputs_use_cache(self):
-        for model_class in self.greedy_sample_model_classes:
-            # enable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            config.use_cache = True
-            config.is_decoder = True
-            model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
-
-    # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
-    # additional post-processing in the former
-    def test_sample_generate(self):
-        for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
-                input_ids.shape[-1],
-                model.config.eos_token_id,
-                forced_bos_token_id=model.config.forced_bos_token_id,
-                forced_eos_token_id=model.config.forced_eos_token_id,
-                max_length=max_length,
-            )
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
-
-            # check `generate()` and `sample()` are equal
-            output_sample, output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                num_return_sequences=3,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
-            )
-            self.assertIsInstance(output_sample, torch.Tensor)
-            self.assertIsInstance(output_generate, torch.Tensor)
-
-    # override since we don't expect the outputs of `.generate` and `.sample` to be the same, since we perform
-    # additional post-processing in the former
-    def test_sample_generate_dict_output(self):
-        for model_class in self.greedy_sample_model_classes:
-            # disable cache
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-            config.use_cache = False
-            model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
-                input_ids.shape[-1],
-                model.config.eos_token_id,
-                forced_bos_token_id=model.config.forced_bos_token_id,
-                forced_eos_token_id=model.config.forced_eos_token_id,
-                max_length=max_length,
-            )
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
-
-            output_sample, output_generate = self._sample_generate(
-                model=model,
-                input_ids=input_ids.to(torch_device),
-                attention_mask=attention_mask.to(torch_device),
-                max_length=max_length,
-                num_return_sequences=1,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-                process_kwargs=process_kwargs,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
-            self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+        warper_kwargs = {}
+        return process_kwargs, warper_kwargs
 
     def test_greedy_generate_stereo_outputs(self):
         for model_class in self.greedy_sample_model_classes:
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.audio_channels = 2
             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model,
                 input_ids=input_ids.to(torch_device),
                 attention_mask=attention_mask.to(torch_device),
@@ -394,9 +274,7 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
                 return_dict_in_generate=True,
             )
 
-            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
             self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
 
             self.assertNotIn(config.pad_token_id, output_generate)
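The hunks above swap the old helper, which built a `LogitsProcessorList` object, for one that returns plain kwargs dicts; `generate()` then constructs the processors internally. A hedged sketch contrasting the two styles (values illustrative; the constructors are the public transformers API):

from transformers import LogitsProcessorList, MinLengthLogitsProcessor

# old style: tests built processor objects and handed them to the decode helper
old_processors = LogitsProcessorList([MinLengthLogitsProcessor(10, eos_token_id=2)])

# new style: tests pass plain kwargs and let generate() build the processors
process_kwargs = {"min_length": 10}
# model.generate(input_ids, **process_kwargs)  # hypothetical call site
print(len(old_processors), process_kwargs)  # 1 {'min_length': 10}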
@@ -817,10 +695,8 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
 
         # generate max 3 tokens
-        decoder_input_ids = inputs_dict["decoder_input_ids"]
-        max_length = decoder_input_ids.shape[-1] + 3
-        decoder_input_ids = decoder_input_ids[: batch_size * config.decoder.num_codebooks, :]
-        return config, input_ids, attention_mask, decoder_input_ids, max_length
+        max_length = 3
+        return config, input_ids, attention_mask, max_length
 
     # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
     # different modalities -> different shapes)
@@ -829,18 +705,14 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
         model,
         input_ids,
         attention_mask,
-        decoder_input_ids,
         max_length,
         output_scores=False,
         output_attentions=False,
         output_hidden_states=False,
         return_dict_in_generate=False,
     ):
-        logits_process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+        logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
            input_ids.shape[-1],
-            eos_token_id=model.config.eos_token_id,
             forced_bos_token_id=model.config.forced_bos_token_id,
             forced_eos_token_id=model.config.forced_eos_token_id,
             max_length=max_length,
         )
@@ -859,34 +731,17 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             **model_kwargs,
         )
 
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_greedy = model.greedy_search(
-                decoder_input_ids,
-                max_length=max_length,
-                logits_processor=logits_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
-                # Ignore copy
-                **model_kwargs,
-            )
-        return output_greedy, output_generate
+        return output_generate
 
     # override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
     # different modalities -> different shapes)
     # Ignore copy
     def _sample_generate(
         self,
         model,
         input_ids,
         attention_mask,
-        decoder_input_ids,
         max_length,
         num_return_sequences,
-        logits_processor,
-        logits_warper,
         logits_warper_kwargs,
         process_kwargs,
         output_scores=False,
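`_greedy_generate` now issues a single `generate()` call instead of also invoking `model.greedy_search` directly. Greedy decoding corresponds to generate's default configuration; a hedged sketch using the public `GenerationConfig` class (parameter values illustrative):

from transformers import GenerationConfig

# Greedy decoding = no sampling, single beam; the flags the test forwards
# (scores, dict output) are ordinary GenerationConfig fields.
greedy = GenerationConfig(
    do_sample=False,
    num_beams=1,
    max_length=8,
    output_scores=True,
    return_dict_in_generate=True,
)
print(greedy.do_sample, greedy.return_dict_in_generate)  # False True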
@@ -912,53 +767,31 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
             **model_kwargs,
         )
 
-        torch.manual_seed(0)
-
-        # prevent flaky generation test failures
-        logits_processor.append(InfNanRemoveLogitsProcessor())
-
-        with torch.no_grad():
-            model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
-            output_sample = model.sample(
-                decoder_input_ids.repeat_interleave(num_return_sequences, dim=0),
-                max_length=max_length,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
-                output_scores=output_scores,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict_in_generate=return_dict_in_generate,
-                **model_kwargs,
-            )
-
-        return output_sample, output_generate
+        return output_generate
 
     @staticmethod
-    def _get_logits_processor_and_kwargs(
+    def _get_logits_processor_and_warper_kwargs(
         input_length,
-        eos_token_id,
         forced_bos_token_id=None,
         forced_eos_token_id=None,
         max_length=None,
-        diversity_penalty=None,
     ):
         process_kwargs = {
             "min_length": input_length + 1 if max_length is None else max_length - 1,
         }
-        logits_processor = LogitsProcessorList()
-        return process_kwargs, logits_processor
+        warper_kwargs = {}
+        return process_kwargs, warper_kwargs
 
     def test_greedy_generate_dict_outputs(self):
         for model_class in self.greedy_sample_model_classes:
             # disable cache
-            config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
             config.use_cache = False
             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model,
                 input_ids=input_ids.to(torch_device),
                 attention_mask=attention_mask.to(torch_device),
-                decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 output_scores=True,
                 output_hidden_states=True,
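Likewise, `_sample_generate` now relies on one `generate()` call; the deleted lines that seeded the RNG and appended `InfNanRemoveLogitsProcessor` before calling `model.sample` directly are no longer needed. A hedged sketch of the equivalent public setup (illustrative values):

import torch

from transformers import GenerationConfig

torch.manual_seed(0)  # seed sampling so test comparisons stay reproducible
sampling = GenerationConfig(do_sample=True, num_return_sequences=3, max_length=8)
print(sampling.do_sample, sampling.num_return_sequences)  # True 3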
@@ -966,7 +799,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 return_dict_in_generate=True,
             )
 
-            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
             self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
 
             self.assertNotIn(config.pad_token_id, output_generate)
@@ -974,16 +806,15 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
     def test_greedy_generate_dict_outputs_use_cache(self):
         for model_class in self.greedy_sample_model_classes:
             # enable cache
-            config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
 
             config.use_cache = True
             config.is_decoder = True
             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model,
                 input_ids=input_ids.to(torch_device),
                 attention_mask=attention_mask.to(torch_device),
-                decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 output_scores=True,
                 output_hidden_states=True,
@@ -991,64 +822,48 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 return_dict_in_generate=True,
             )
 
-            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
             self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
 
     def test_sample_generate(self):
         for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                model.config.eos_token_id,
                 forced_bos_token_id=model.config.forced_bos_token_id,
                 forced_eos_token_id=model.config.forced_eos_token_id,
                 max_length=max_length,
             )
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
 
             # check `generate()` and `sample()` are equal
-            output_sample, output_generate = self._sample_generate(
+            output_generate = self._sample_generate(
                 model=model,
                 input_ids=input_ids.to(torch_device),
                 attention_mask=attention_mask.to(torch_device),
-                decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 num_return_sequences=1,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
                 logits_warper_kwargs=logits_warper_kwargs,
                 process_kwargs=process_kwargs,
             )
-            self.assertIsInstance(output_sample, torch.Tensor)
             self.assertIsInstance(output_generate, torch.Tensor)
 
     def test_sample_generate_dict_output(self):
         for model_class in self.greedy_sample_model_classes:
             # disable cache
-            config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.use_cache = False
             model = model_class(config).to(torch_device).eval()
-            process_kwargs, logits_processor = self._get_logits_processor_and_kwargs(
+            process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
                 input_ids.shape[-1],
-                model.config.eos_token_id,
                 forced_bos_token_id=model.config.forced_bos_token_id,
                 forced_eos_token_id=model.config.forced_eos_token_id,
                 max_length=max_length,
             )
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)
 
-            output_sample, output_generate = self._sample_generate(
+            output_generate = self._sample_generate(
                 model=model,
                 input_ids=input_ids.to(torch_device),
                 attention_mask=attention_mask.to(torch_device),
-                decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 num_return_sequences=3,
-                logits_processor=logits_processor,
-                logits_warper=logits_warper,
                 logits_warper_kwargs=logits_warper_kwargs,
                 process_kwargs=process_kwargs,
                 output_scores=True,
@@ -1057,11 +872,10 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 return_dict_in_generate=True,
             )
 
-            self.assertIsInstance(output_sample, GenerateDecoderOnlyOutput)
             self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
 
     def test_generate_without_input_ids(self):
-        config, _, _, _, max_length = self._get_input_ids_and_config()
+        config, _, _, max_length = self._get_input_ids_and_config()
 
         # if no bos token id => cannot generate from None
         if config.bos_token_id is None:
@@ -1090,15 +904,14 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
     def test_greedy_generate_stereo_outputs(self):
         for model_class in self.greedy_sample_model_classes:
-            config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
+            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
             config.audio_channels = 2
             model = model_class(config).to(torch_device).eval()
-            output_greedy, output_generate = self._greedy_generate(
+            output_generate = self._greedy_generate(
                 model=model,
                 input_ids=input_ids.to(torch_device),
                 attention_mask=attention_mask.to(torch_device),
-                decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 output_scores=True,
                 output_hidden_states=True,
@@ -1106,7 +919,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
                 return_dict_in_generate=True,
             )
 
-            self.assertIsInstance(output_greedy, GenerateDecoderOnlyOutput)
             self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
 
             self.assertNotIn(config.pad_token_id, output_generate)
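The stereo tests differ from the mono ones only by setting `audio_channels = 2` on the config before building the model. A hedged sketch with the public config class (constructing a config this way is an assumption about usage, not taken from the commit):

from transformers import MusicgenMelodyDecoderConfig

# With audio_channels=2 the decoder interleaves left/right codebook streams,
# so twice as many token streams are decoded per step.
config = MusicgenMelodyDecoderConfig(audio_channels=2)
print(config.audio_channels)  # 2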