Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
03f98f96
Unverified
Commit
03f98f96
authored
Jul 28, 2023
by
Sanchit Gandhi
Committed by
GitHub
Jul 28, 2023
Browse files
[MusicGen] Fix integration tests (#25169)
* move to device * update with cuda values * fix fp16 * more rigorous
parent
c90e14fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
26 deletions
+27
-26
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+1
-4
tests/models/musicgen/test_modeling_musicgen.py
tests/models/musicgen/test_modeling_musicgen.py
+26
-22
No files found.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
03f98f96
...
...
@@ -773,10 +773,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
past_key_values_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
if
past_key_values
is
not
None
else
0
if
inputs_embeds
is
None
:
inputs_embeds
=
torch
.
zeros
((
bsz
,
seq_len
,
self
.
d_model
),
device
=
input_ids
.
device
)
for
codebook
in
range
(
num_codebooks
):
inputs_embeds
+=
self
.
embed_tokens
[
codebook
](
input
[:,
codebook
])
inputs_embeds
=
sum
([
self
.
embed_tokens
[
codebook
](
input
[:,
codebook
])
for
codebook
in
range
(
num_codebooks
)])
attention_mask
=
self
.
_prepare_decoder_attention_mask
(
attention_mask
,
input_shape
,
inputs_embeds
,
past_key_values_length
...
...
tests/models/musicgen/test_modeling_musicgen.py
View file @
03f98f96
...
...
@@ -267,8 +267,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
output_scores
=
True
,
output_hidden_states
=
True
,
...
...
@@ -293,8 +293,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
output_scores
=
True
,
output_hidden_states
=
True
,
...
...
@@ -324,8 +324,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# check `generate()` and `sample()` are equal
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
num_return_sequences
=
3
,
logits_processor
=
logits_processor
,
...
...
@@ -356,8 +356,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
num_return_sequences
=
1
,
logits_processor
=
logits_processor
,
...
...
@@ -964,8 +964,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
output_scores
=
True
,
...
...
@@ -989,8 +989,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
output_scores
=
True
,
...
...
@@ -1019,8 +1019,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# check `generate()` and `sample()` are equal
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
num_return_sequences
=
1
,
...
...
@@ -1050,8 +1050,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
num_return_sequences
=
3
,
...
...
@@ -1089,8 +1089,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model
=
model_class
(
config
).
eval
().
to
(
torch_device
)
if
torch_device
==
"cuda"
:
model
.
half
()
model
.
generate
(
**
input_dict
,
max_new_tokens
=
10
)
model
.
generate
(
**
input_dict
,
do_sample
=
True
,
max_new_tokens
=
10
)
# greedy
model
.
generate
(
input_dict
[
"input_ids"
],
attention_mask
=
input_dict
[
"attention_mask"
],
max_new_tokens
=
10
)
# sampling
model
.
generate
(
input_dict
[
"input_ids"
],
attention_mask
=
input_dict
[
"attention_mask"
],
do_sample
=
True
,
max_new_tokens
=
10
)
def
get_bip_bip
(
bip_duration
=
0.125
,
duration
=
0.5
,
sample_rate
=
32000
):
...
...
@@ -1230,8 +1234,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_VALUES
=
torch
.
tensor
(
[
0.0
765
,
0.0758
,
0.0
749
,
0.07
5
9
,
0.0
759
,
0.0
771
,
0.0
77
5
,
0.0
760
,
0.0
762
,
0.0
765
,
0.0
767
,
0.0
760
,
0.0
738
,
0.0
714
,
0.0
713
,
0.0
730
,
-
0.0
099
,
-
0.0
140
,
0.
0
079
,
0.0
080
,
-
0.0
046
,
0.0
06
5
,
-
0.0
068
,
-
0.0185
,
0.0
105
,
0.0
059
,
0.0
329
,
0.0
249
,
-
0.0
204
,
-
0.0
341
,
-
0.0
465
,
0.0
053
,
]
)
# fmt: on
...
...
@@ -1312,8 +1316,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_VALUES
=
torch
.
tensor
(
[
-
0.0
047
,
-
0.0
09
4
,
-
0.00
28
,
-
0.00
1
8
,
-
0.00
57
,
-
0.00
07
,
-
0.010
4
,
-
0.02
11
,
-
0.00
97
,
-
0.0
150
,
-
0.0
066
,
-
0.00
0
4
,
-
0.02
0
1
,
-
0.0
325
,
-
0.0
326
,
-
0.0
098
,
-
0.0
111
,
-
0.0
15
4
,
0.00
47
,
0.00
5
8
,
-
0.00
68
,
0.00
12
,
-
0.010
9
,
-
0.02
29
,
0.00
10
,
-
0.0
038
,
0.0
167
,
0.004
2
,
-
0.0
4
21
,
-
0.0
610
,
-
0.0
764
,
-
0.0
326
,
]
)
# fmt: on
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment