Unverified commit f16ff0f0, authored by Sanchit Gandhi and committed by GitHub

MusicGen Update (#27084)

* [MusicGen] Add stereo model

* safe serialization

* Update src/transformers/models/musicgen/modeling_musicgen.py

* split over 2 lines

* fix slow tests on cuda
parent 5ef650b0
@@ -57,6 +57,11 @@ Generation is limited by the sinusoidal positional embeddings to 30 second input
than 30 seconds of audio (1503 tokens), and input audio passed by Audio-Prompted Generation contributes to this limit so,
given an input of 20 seconds of audio, MusicGen cannot generate more than 10 seconds of additional audio.

Transformers supports both mono (1-channel) and stereo (2-channel) variants of MusicGen. The mono channel versions
generate a single set of codebooks. The stereo versions generate 2 sets of codebooks, 1 for each channel (left/right),
and each set of codebooks is decoded independently through the audio compression model. The audio streams for each
channel are combined to give the final stereo output.

### Unconditional Generation

The inputs for unconditional (or 'null') generation can be obtained through the method
......
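For illustration, a minimal usage sketch of the stereo variant described in the documentation above, assuming the `facebook/musicgen-stereo-small` checkpoint added by this change and the existing unconditional-generation helper; stereo models return waveforms of shape `(batch, 2, samples)`, one stream per channel:

```python
from transformers import MusicgenForConditionalGeneration

# illustrative sketch: load a stereo checkpoint and generate a short clip
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small")

# 'null' (unconditional) prompts, one sample in the batch
unconditional_inputs = model.get_unconditional_inputs(num_samples=1)

# greedy decoding; stereo checkpoints return (batch, 2, samples) waveforms,
# channel 0 is the left channel, channel 1 the right channel
audio_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=256)
print(audio_values.shape)  # e.g. torch.Size([1, 2, <samples>])
```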
@@ -75,6 +75,9 @@ class MusicgenDecoderConfig(PretrainedConfig):
            The number of parallel codebooks forwarded to the model.
        tie_word_embeddings(`bool`, *optional*, defaults to `False`):
            Whether input and output word embeddings should be tied.
        audio_channels (`int`, *optional*, defaults to 1):
            Number of channels in the audio data. Either 1 for mono or 2 for stereo. Stereo models generate a separate
            audio stream for the left/right output channels. Mono models generate a single audio stream output.
    """

    model_type = "musicgen_decoder"
    keys_to_ignore_at_inference = ["past_key_values"]
@@ -96,6 +99,7 @@ class MusicgenDecoderConfig(PretrainedConfig):
        initializer_factor=0.02,
        scale_embedding=False,
        num_codebooks=4,
        audio_channels=1,
        pad_token_id=2048,
        bos_token_id=2048,
        eos_token_id=None,
@@ -117,6 +121,11 @@ class MusicgenDecoderConfig(PretrainedConfig):
        self.use_cache = use_cache
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.num_codebooks = num_codebooks

        if audio_channels not in [1, 2]:
            raise ValueError(f"Expected 1 (mono) or 2 (stereo) audio channels, got {audio_channels} channels.")

        self.audio_channels = audio_channels
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
......
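As a quick sketch of the new configuration option (assuming `MusicgenDecoderConfig` is exposed at the top level of `transformers`, like the other MusicGen classes touched in this diff), any value other than 1 or 2 channels is rejected at construction time:

```python
from transformers import MusicgenDecoderConfig

# stereo decoder config: 2 audio channels, one set of 4 codebooks per channel
stereo_config = MusicgenDecoderConfig(num_codebooks=8, audio_channels=2)
print(stereo_config.audio_channels)  # 2

# anything other than 1 (mono) or 2 (stereo) raises a ValueError
try:
    MusicgenDecoderConfig(audio_channels=3)
except ValueError as err:
    print(err)  # Expected 1 (mono) or 2 (stereo) audio channels, got 3 channels.
```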
@@ -88,32 +88,48 @@ def rename_state_dict(state_dict: OrderedDict, hidden_size: int) -> Tuple[Dict,
def decoder_config_from_checkpoint(checkpoint: str) -> MusicgenDecoderConfig:
    if checkpoint == "small" or checkpoint == "facebook/musicgen-stereo-small":
        # default config values
        hidden_size = 1024
        num_hidden_layers = 24
        num_attention_heads = 16
    elif checkpoint == "medium" or checkpoint == "facebook/musicgen-stereo-medium":
        hidden_size = 1536
        num_hidden_layers = 48
        num_attention_heads = 24
    elif checkpoint == "large" or checkpoint == "facebook/musicgen-stereo-large":
        hidden_size = 2048
        num_hidden_layers = 48
        num_attention_heads = 32
    else:
        raise ValueError(
            "Checkpoint should be one of `['small', 'medium', 'large']` for the mono checkpoints, "
            "or `['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
            f"for the stereo checkpoints, got {checkpoint}."
        )

    if "stereo" in checkpoint:
        audio_channels = 2
        num_codebooks = 8
    else:
        audio_channels = 1
        num_codebooks = 4

    config = MusicgenDecoderConfig(
        hidden_size=hidden_size,
        ffn_dim=hidden_size * 4,
        num_hidden_layers=num_hidden_layers,
        num_attention_heads=num_attention_heads,
        num_codebooks=num_codebooks,
        audio_channels=audio_channels,
    )
    return config


@torch.no_grad()
def convert_musicgen_checkpoint(
    checkpoint, pytorch_dump_folder=None, repo_id=None, device="cpu", safe_serialization=False
):
    fairseq_model = MusicGen.get_pretrained(checkpoint, device=device)
    decoder_config = decoder_config_from_checkpoint(checkpoint)
@@ -146,18 +162,20 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No
    model.enc_to_dec_proj.load_state_dict(enc_dec_proj_state_dict)

    # check we can do a forward pass
    input_ids = torch.arange(0, 2 * decoder_config.num_codebooks, dtype=torch.long).reshape(2, -1)
    decoder_input_ids = input_ids.reshape(2 * decoder_config.num_codebooks, -1)

    with torch.no_grad():
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

    if logits.shape != (2 * decoder_config.num_codebooks, 1, 2048):
        raise ValueError("Incorrect shape for logits")

    # now construct the processor
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        "facebook/encodec_32khz", padding_side="left", feature_size=decoder_config.audio_channels
    )

    processor = MusicgenProcessor(feature_extractor=feature_extractor, tokenizer=tokenizer)
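For reference, a small sketch of what the stereo processor built above expects at inference time. It assumes the converted `facebook/musicgen-stereo-small` checkpoint is available on the Hub and that the EnCodec feature extractor exposes its usual `input_values` output; the stereo integration test later in this diff passes channels-first arrays in the same way:

```python
import numpy as np
from transformers import MusicgenProcessor

processor = MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small")

# two examples of stereo audio at 32 kHz: shape (channels=2, samples)
audio = [np.random.randn(2, 32000), np.random.randn(2, 16000)]
text = ["80s music", "Club techno"]

inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
# with feature_size=2 the two channels are preserved: (batch, 2, samples)
print(inputs["input_values"].shape)
```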
@@ -173,12 +191,12 @@ def convert_musicgen_checkpoint(checkpoint, pytorch_dump_folder=None, repo_id=No
    if pytorch_dump_folder is not None:
        Path(pytorch_dump_folder).mkdir(exist_ok=True)
        logger.info(f"Saving model {checkpoint} to {pytorch_dump_folder}")
        model.save_pretrained(pytorch_dump_folder, safe_serialization=safe_serialization)
        processor.save_pretrained(pytorch_dump_folder)

    if repo_id:
        logger.info(f"Pushing model {checkpoint} to {repo_id}")
        model.push_to_hub(repo_id, safe_serialization=safe_serialization)
        processor.push_to_hub(repo_id)
@@ -189,7 +207,10 @@ if __name__ == "__main__":
        "--checkpoint",
        default="small",
        type=str,
        help="Checkpoint size of the MusicGen model you'd like to convert. Can be one of: "
        "`['small', 'medium', 'large']` for the mono checkpoints, or "
        "`['facebook/musicgen-stereo-small', 'facebook/musicgen-stereo-medium', 'facebook/musicgen-stereo-large']` "
        "for the stereo checkpoints.",
    )
    parser.add_argument(
        "--pytorch_dump_folder",
@@ -204,6 +225,11 @@ if __name__ == "__main__":
    parser.add_argument(
        "--device", default="cpu", type=str, help="Torch device to run the conversion, either cpu or cuda."
    )
    parser.add_argument(
        "--safe_serialization",
        action="store_true",
        help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).",
    )

    args = parser.parse_args()
    convert_musicgen_checkpoint(args.checkpoint, args.pytorch_dump_folder, args.push_to_hub)
@@ -1077,21 +1077,33 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
        )

        channel_codebooks = num_codebooks // 2 if self.config.audio_channels == 2 else num_codebooks

        # we only apply the mask if we have a large enough seq len - otherwise we return as is
        if max_length < 2 * channel_codebooks - 1:
            return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)

        # fill the shifted ids with the prompt entries, offset by the codebook idx
        for codebook in range(channel_codebooks):
            if self.config.audio_channels == 1:
                # mono channel - loop over the codebooks one-by-one
                input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
            else:
                # left/right channels are interleaved in the generated codebooks, so handle one then the other
                input_ids_shifted[:, 2 * codebook, codebook : seq_len + codebook] = input_ids[:, 2 * codebook]
                input_ids_shifted[:, 2 * codebook + 1, codebook : seq_len + codebook] = input_ids[:, 2 * codebook + 1]

        # construct a pattern mask that indicates the positions of padding tokens for each codebook
        # first fill the upper triangular part (the EOS padding)
        delay_pattern = torch.triu(
            torch.ones((channel_codebooks, max_length), dtype=torch.bool), diagonal=max_length - channel_codebooks + 1
        )
        # then fill the lower triangular part (the BOS padding)
        delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))

        if self.config.audio_channels == 2:
            # for left/right channel we need to duplicate every row of the pattern mask in an interleaved fashion
            delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

        mask = ~delay_pattern.to(input_ids.device)
        input_ids = mask * input_ids_shifted + ~mask * pad_token_id
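A self-contained sketch of the stereo delay-pattern logic in the hunk above, using hypothetical shapes (`num_codebooks=8`, i.e. 4 codebooks per channel, and `max_length=10`), showing how the per-channel mask is built once and then row-interleaved for the left/right channels:

```python
import torch

num_codebooks, audio_channels, max_length = 8, 2, 10
channel_codebooks = num_codebooks // 2 if audio_channels == 2 else num_codebooks

# EOS padding (upper triangle) plus BOS padding (lower triangle), built per channel
delay_pattern = torch.triu(
    torch.ones((channel_codebooks, max_length), dtype=torch.bool),
    diagonal=max_length - channel_codebooks + 1,
)
delay_pattern = delay_pattern + torch.tril(torch.ones((channel_codebooks, max_length), dtype=torch.bool))

if audio_channels == 2:
    # duplicate every row so interleaved left/right codebooks share the same delay
    delay_pattern = delay_pattern.repeat_interleave(2, dim=0)

print(delay_pattern.shape)   # torch.Size([8, 10]) - one row per (interleaved) codebook
print(delay_pattern.int())   # rows 0/1 identical, rows 2/3 identical, ...
```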
@@ -1856,6 +1868,11 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
                f"Expected 1 frame in the audio code outputs, got {frames} frames. Ensure chunking is "
                "disabled by setting `chunk_length=None` in the audio encoder."
            )

        if self.config.audio_channels == 2 and audio_codes.shape[2] == self.decoder.num_codebooks // 2:
            # mono input through encodec that we convert to stereo
            audio_codes = audio_codes.repeat_interleave(2, dim=2)

        decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)

        # Decode
@@ -2074,12 +2091,42 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
        # 3. make sure that encoder returns `ModelOutput`
        model_input_name = model_input_name if model_input_name is not None else self.audio_encoder.main_input_name
        encoder_kwargs["return_dict"] = True

        if self.decoder.config.audio_channels == 1:
            encoder_kwargs[model_input_name] = input_values
            audio_encoder_outputs = encoder.encode(**encoder_kwargs)
            audio_codes = audio_encoder_outputs.audio_codes
            audio_scales = audio_encoder_outputs.audio_scales

            frames, bsz, codebooks, seq_len = audio_codes.shape
        else:
            if input_values.shape[1] != 2:
                raise ValueError(
                    f"Expected stereo audio (2-channels) but example has {input_values.shape[1]} channel."
                )

            encoder_kwargs[model_input_name] = input_values[:, :1, :]
            audio_encoder_outputs_left = encoder.encode(**encoder_kwargs)
            audio_codes_left = audio_encoder_outputs_left.audio_codes
            audio_scales_left = audio_encoder_outputs_left.audio_scales

            encoder_kwargs[model_input_name] = input_values[:, 1:, :]
            audio_encoder_outputs_right = encoder.encode(**encoder_kwargs)
            audio_codes_right = audio_encoder_outputs_right.audio_codes
            audio_scales_right = audio_encoder_outputs_right.audio_scales

            frames, bsz, codebooks, seq_len = audio_codes_left.shape
            # copy alternating left/right channel codes into stereo codebook
            audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))

            audio_codes[:, :, ::2, :] = audio_codes_left
            audio_codes[:, :, 1::2, :] = audio_codes_right

            if audio_scales_left != [None] or audio_scales_right != [None]:
                audio_scales = torch.stack([audio_scales_left, audio_scales_right], dim=1)
            else:
                audio_scales = [None] * bsz

        if frames != 1:
            raise ValueError(
@@ -2090,7 +2137,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
        decoder_input_ids = audio_codes[0, ...].reshape(bsz * self.decoder.num_codebooks, seq_len)
        model_kwargs["decoder_input_ids"] = decoder_input_ids
        model_kwargs["audio_scales"] = audio_scales
        return model_kwargs

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
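For intuition, a toy sketch of the channel-split-and-interleave step in the encode hunk above, using random tensors in place of the EnCodec outputs (hypothetical shapes: 1 frame, batch of 2, 4 codebooks per channel):

```python
import torch

frames, bsz, codebooks, seq_len = 1, 2, 4, 50

# stand-ins for the codes returned by encoding the left/right channels separately
audio_codes_left = torch.randint(0, 2048, (frames, bsz, codebooks, seq_len))
audio_codes_right = torch.randint(0, 2048, (frames, bsz, codebooks, seq_len))

# interleave: even codebook indices hold the left channel, odd indices the right
audio_codes = audio_codes_left.new_ones((frames, bsz, 2 * codebooks, seq_len))
audio_codes[:, :, ::2, :] = audio_codes_left
audio_codes[:, :, 1::2, :] = audio_codes_right

assert torch.equal(audio_codes[:, :, ::2, :], audio_codes_left)
assert torch.equal(audio_codes[:, :, 1::2, :], audio_codes_right)
print(audio_codes.shape)  # torch.Size([1, 2, 8, 50])
```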
@@ -2433,16 +2480,25 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
        if audio_scales is None:
            audio_scales = [None] * batch_size

        if self.decoder.config.audio_channels == 1:
            output_values = self.audio_encoder.decode(
                output_ids,
                audio_scales=audio_scales,
            ).audio_values
        else:
            codec_outputs_left = self.audio_encoder.decode(output_ids[:, :, ::2, :], audio_scales=audio_scales)
            output_values_left = codec_outputs_left.audio_values

            codec_outputs_right = self.audio_encoder.decode(output_ids[:, :, 1::2, :], audio_scales=audio_scales)
            output_values_right = codec_outputs_right.audio_values

            output_values = torch.cat([output_values_left, output_values_right], dim=1)

        if generation_config.return_dict_in_generate:
            outputs.sequences = output_values
            return outputs
        else:
            return output_values

    def get_unconditional_inputs(self, num_samples=1):
        """
......
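And in the inverse direction, a toy sketch of how the generated codebooks are split back into channels, mirroring the decode branch above; the tensors and the `fake_decode` helper are stand-ins for `self.audio_encoder.decode(...).audio_values`, not real model calls:

```python
import torch

bsz, num_codebooks, seq_len, samples = 2, 8, 50, 1600

# stand-in for the generated stereo codebooks (left on even rows, right on odd rows)
output_ids = torch.randint(0, 2048, (1, bsz, num_codebooks, seq_len))

def fake_decode(codes):
    # placeholder for the EnCodec decode call; returns a (batch, 1, samples) waveform
    return torch.zeros((bsz, 1, samples))

output_values_left = fake_decode(output_ids[:, :, ::2, :])
output_values_right = fake_decode(output_ids[:, :, 1::2, :])

# concatenate along the channel dimension -> (batch, 2, samples) stereo waveform
output_values = torch.cat([output_values_left, output_values_right], dim=1)
print(output_values.shape)  # torch.Size([2, 2, 1600])
```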
@@ -379,6 +379,27 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
        self.assertIsInstance(output_sample, SampleDecoderOnlyOutput)
        self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)

    def test_greedy_generate_stereo_outputs(self):
        for model_class in self.greedy_sample_model_classes:
            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
            config.audio_channels = 2

            model = model_class(config).to(torch_device).eval()
            output_greedy, output_generate = self._greedy_generate(
                model=model,
                input_ids=input_ids.to(torch_device),
                attention_mask=attention_mask.to(torch_device),
                max_length=max_length,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            self.assertIsInstance(output_greedy, GreedySearchDecoderOnlyOutput)
            self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)

            self.assertNotIn(config.pad_token_id, output_generate)


def prepare_musicgen_inputs_dict(
    config,
@@ -1102,6 +1123,29 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
            input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
        )

    def test_greedy_generate_stereo_outputs(self):
        for model_class in self.greedy_sample_model_classes:
            config, input_ids, attention_mask, decoder_input_ids, max_length = self._get_input_ids_and_config()
            config.audio_channels = 2

            model = model_class(config).to(torch_device).eval()
            output_greedy, output_generate = self._greedy_generate(
                model=model,
                input_ids=input_ids.to(torch_device),
                attention_mask=attention_mask.to(torch_device),
                decoder_input_ids=decoder_input_ids,
                max_length=max_length,
                output_scores=True,
                output_hidden_states=True,
                output_attentions=True,
                return_dict_in_generate=True,
            )

            self.assertIsInstance(output_greedy, GreedySearchEncoderDecoderOutput)
            self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)

            self.assertNotIn(config.pad_token_id, output_generate)


def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
    """Produces a series of 'bip bip' sounds at a given frequency."""
@@ -1357,3 +1401,79 @@ class MusicgenIntegrationTests(unittest.TestCase):
            output_values.shape == (2, 1, 36480)
        )  # input values take shape 32000 and we generate from there

        self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES, atol=1e-4))


@require_torch
class MusicgenStereoIntegrationTests(unittest.TestCase):
    @cached_property
    def model(self):
        return MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-stereo-small").to(torch_device)

    @cached_property
    def processor(self):
        return MusicgenProcessor.from_pretrained("facebook/musicgen-stereo-small")

    @slow
    def test_generate_unconditional_greedy(self):
        model = self.model

        # only generate 1 sample with greedy - since it's deterministic all elements of the batch will be the same
        unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
        unconditional_inputs = place_dict_on_device(unconditional_inputs, device=torch_device)

        output_values = model.generate(**unconditional_inputs, do_sample=False, max_new_tokens=12)

        # fmt: off
        EXPECTED_VALUES_LEFT = torch.tensor(
            [
                0.0017, 0.0004, 0.0004, 0.0005, 0.0002, 0.0002, -0.0002, -0.0013,
                -0.0010, -0.0015, -0.0018, -0.0032, -0.0060, -0.0082, -0.0096, -0.0099,
            ]
        )
        EXPECTED_VALUES_RIGHT = torch.tensor(
            [
                0.0038, 0.0028, 0.0031, 0.0032, 0.0031, 0.0032, 0.0030, 0.0019,
                0.0021, 0.0015, 0.0009, -0.0008, -0.0040, -0.0067, -0.0087, -0.0096,
            ]
        )
        # fmt: on

        # (bsz, channels, seq_len)
        self.assertTrue(output_values.shape == (1, 2, 5760))
        self.assertTrue(torch.allclose(output_values[0, 0, :16].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
        self.assertTrue(torch.allclose(output_values[0, 1, :16].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))

    @slow
    def test_generate_text_audio_prompt(self):
        model = self.model
        processor = self.processor

        # create stereo inputs
        audio = [get_bip_bip(duration=0.5)[None, :].repeat(2, 0), get_bip_bip(duration=1.0)[None, :].repeat(2, 0)]
        text = ["80s music", "Club techno"]

        inputs = processor(audio=audio, text=text, padding=True, return_tensors="pt")
        inputs = place_dict_on_device(inputs, device=torch_device)

        output_values = model.generate(**inputs, do_sample=False, guidance_scale=3.0, max_new_tokens=12)

        # fmt: off
        EXPECTED_VALUES_LEFT = torch.tensor(
            [
                0.2535, 0.2008, 0.1471, 0.0896, 0.0306, -0.0200, -0.0501, -0.0728,
                -0.0832, -0.0856, -0.0867, -0.0884, -0.0864, -0.0866, -0.0744, -0.0430,
            ]
        )
        EXPECTED_VALUES_RIGHT = torch.tensor(
            [
                0.1695, 0.1213, 0.0732, 0.0239, -0.0264, -0.0705, -0.0935, -0.1103,
                -0.1163, -0.1139, -0.1104, -0.1082, -0.1027, -0.1004, -0.0900, -0.0614,
            ]
        )
        # fmt: on

        # (bsz, channels, seq_len)
        self.assertTrue(output_values.shape == (2, 2, 37760))
        # input values take shape 32000 and we generate from there - we check the last (generated) values
        self.assertTrue(torch.allclose(output_values[0, 0, -16:].cpu(), EXPECTED_VALUES_LEFT, atol=1e-4))
        self.assertTrue(torch.allclose(output_values[0, 1, -16:].cpu(), EXPECTED_VALUES_RIGHT, atol=1e-4))