Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
03f98f96
Unverified
Commit
03f98f96
authored
Jul 28, 2023
by
Sanchit Gandhi
Committed by
GitHub
Jul 28, 2023
Browse files
[MusicGen] Fix integration tests (#25169)
* move to device * update with cuda values * fix fp16 * more rigorous
parent
c90e14fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
26 deletions
+27
-26
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+1
-4
tests/models/musicgen/test_modeling_musicgen.py
tests/models/musicgen/test_modeling_musicgen.py
+26
-22
No files found.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
03f98f96
...
...
@@ -773,10 +773,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
past_key_values_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
if
past_key_values
is
not
None
else
0
if
inputs_embeds
is
None
:
inputs_embeds
=
torch
.
zeros
((
bsz
,
seq_len
,
self
.
d_model
),
device
=
input_ids
.
device
)
for
codebook
in
range
(
num_codebooks
):
inputs_embeds
+=
self
.
embed_tokens
[
codebook
](
input
[:,
codebook
])
inputs_embeds
=
sum
([
self
.
embed_tokens
[
codebook
](
input
[:,
codebook
])
for
codebook
in
range
(
num_codebooks
)])
attention_mask
=
self
.
_prepare_decoder_attention_mask
(
attention_mask
,
input_shape
,
inputs_embeds
,
past_key_values_length
...
...
tests/models/musicgen/test_modeling_musicgen.py
View file @
03f98f96
...
...
@@ -267,8 +267,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
output_scores
=
True
,
output_hidden_states
=
True
,
...
...
@@ -293,8 +293,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
output_scores
=
True
,
output_hidden_states
=
True
,
...
...
@@ -324,8 +324,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# check `generate()` and `sample()` are equal
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
num_return_sequences
=
3
,
logits_processor
=
logits_processor
,
...
...
@@ -356,8 +356,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
max_length
=
max_length
,
num_return_sequences
=
1
,
logits_processor
=
logits_processor
,
...
...
@@ -964,8 +964,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
output_scores
=
True
,
...
...
@@ -989,8 +989,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model
=
model_class
(
config
).
to
(
torch_device
).
eval
()
output_greedy
,
output_generate
=
self
.
_greedy_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
output_scores
=
True
,
...
...
@@ -1019,8 +1019,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# check `generate()` and `sample()` are equal
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
num_return_sequences
=
1
,
...
...
@@ -1050,8 +1050,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_sample
,
output_generate
=
self
.
_sample_generate
(
model
=
model
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
input_ids
=
input_ids
.
to
(
torch_device
)
,
attention_mask
=
attention_mask
.
to
(
torch_device
)
,
decoder_input_ids
=
decoder_input_ids
,
max_length
=
max_length
,
num_return_sequences
=
3
,
...
...
@@ -1089,8 +1089,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model
=
model_class
(
config
).
eval
().
to
(
torch_device
)
if
torch_device
==
"cuda"
:
model
.
half
()
model
.
generate
(
**
input_dict
,
max_new_tokens
=
10
)
model
.
generate
(
**
input_dict
,
do_sample
=
True
,
max_new_tokens
=
10
)
# greedy
model
.
generate
(
input_dict
[
"input_ids"
],
attention_mask
=
input_dict
[
"attention_mask"
],
max_new_tokens
=
10
)
# sampling
model
.
generate
(
input_dict
[
"input_ids"
],
attention_mask
=
input_dict
[
"attention_mask"
],
do_sample
=
True
,
max_new_tokens
=
10
)
def
get_bip_bip
(
bip_duration
=
0.125
,
duration
=
0.5
,
sample_rate
=
32000
):
...
...
@@ -1230,8 +1234,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_VALUES
=
torch
.
tensor
(
[
0.0
765
,
0.0758
,
0.0
749
,
0.07
5
9
,
0.0
759
,
0.0
771
,
0.0
77
5
,
0.0
760
,
0.0
762
,
0.0
765
,
0.0
767
,
0.0
760
,
0.0
738
,
0.0
714
,
0.0
713
,
0.0
730
,
-
0.0
099
,
-
0.0
140
,
0.
0
079
,
0.0
080
,
-
0.0
046
,
0.0
06
5
,
-
0.0
068
,
-
0.0185
,
0.0
105
,
0.0
059
,
0.0
329
,
0.0
249
,
-
0.0
204
,
-
0.0
341
,
-
0.0
465
,
0.0
053
,
]
)
# fmt: on
...
...
@@ -1312,8 +1316,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_VALUES
=
torch
.
tensor
(
[
-
0.0
047
,
-
0.0
09
4
,
-
0.00
28
,
-
0.00
1
8
,
-
0.00
57
,
-
0.00
07
,
-
0.010
4
,
-
0.02
11
,
-
0.00
97
,
-
0.0
150
,
-
0.0
066
,
-
0.00
0
4
,
-
0.02
0
1
,
-
0.0
325
,
-
0.0
326
,
-
0.0
098
,
-
0.0
111
,
-
0.0
15
4
,
0.00
47
,
0.00
5
8
,
-
0.00
68
,
0.00
12
,
-
0.010
9
,
-
0.02
29
,
0.00
10
,
-
0.0
038
,
0.0
167
,
0.004
2
,
-
0.0
4
21
,
-
0.0
610
,
-
0.0
764
,
-
0.0
326
,
]
)
# fmt: on
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment