Unverified Commit 03f98f96 authored by Sanchit Gandhi, committed by GitHub

[MusicGen] Fix integration tests (#25169)

* move to device

* update with cuda values

* fix fp16

* more rigorous
parent c90e14fb
@@ -773,10 +773,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
-            inputs_embeds = torch.zeros((bsz, seq_len, self.d_model), device=input_ids.device)
-            for codebook in range(num_codebooks):
-                inputs_embeds += self.embed_tokens[codebook](input[:, codebook])
+            inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
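The modeling change above replaces a `torch.zeros` accumulator, which defaults to float32, with a direct sum of the embedding outputs, so `inputs_embeds` inherits the dtype of the embedding weights (float16 under `model.half()`). A standalone sketch of the dtype behaviour, with made-up sizes and names, not the model code itself:

```python
import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden_size = 4, 8, 16
embed_tokens = nn.ModuleList(nn.Embedding(vocab_size, hidden_size) for _ in range(num_codebooks)).half()
input_codes = torch.randint(vocab_size, (2, num_codebooks, 5))  # (bsz, num_codebooks, seq_len)

# Old pattern: torch.zeros defaults to float32, so the accumulation promotes
# the result back to float32 even though the embedding weights are float16.
old = torch.zeros((2, 5, hidden_size))
for codebook in range(num_codebooks):
    old = old + embed_tokens[codebook](input_codes[:, codebook])
print(old.dtype)  # torch.float32 -> mismatches a model cast with .half()

# New pattern: summing the embedding outputs keeps the weights' dtype.
new = sum(embed_tokens[codebook](input_codes[:, codebook]) for codebook in range(num_codebooks))
print(new.dtype)  # torch.float16
```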
@@ -267,8 +267,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             model = model_class(config).to(torch_device).eval()
             output_greedy, output_generate = self._greedy_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 max_length=max_length,
                 output_scores=True,
                 output_hidden_states=True,
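This hunk and the following ones all apply the same fix: the tester builds `input_ids` and `attention_mask` on CPU while the model has already been moved to `torch_device`, so the inputs must be moved explicitly before generation. A minimal, hypothetical illustration of the pattern (the `nn.Embedding` stands in for the test model):

```python
import torch
import torch.nn as nn

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the test model; the real tests build a MusicGen model from the tester config.
model = nn.Embedding(10, 8).to(torch_device).eval()

# The test helper creates the inputs on CPU.
input_ids = torch.randint(10, (2, 4))
attention_mask = torch.ones_like(input_ids)

# Without the explicit .to(torch_device), a CUDA run fails with a device
# mismatch between the indices and the embedding weights.
hidden_states = model(input_ids.to(torch_device))
attention_mask = attention_mask.to(torch_device)
```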
@@ -293,8 +293,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             model = model_class(config).to(torch_device).eval()
             output_greedy, output_generate = self._greedy_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 max_length=max_length,
                 output_scores=True,
                 output_hidden_states=True,
@@ -324,8 +324,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             # check `generate()` and `sample()` are equal
             output_sample, output_generate = self._sample_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 max_length=max_length,
                 num_return_sequences=3,
                 logits_processor=logits_processor,
@@ -356,8 +356,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
             output_sample, output_generate = self._sample_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 max_length=max_length,
                 num_return_sequences=1,
                 logits_processor=logits_processor,
@@ -964,8 +964,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             model = model_class(config).to(torch_device).eval()
             output_greedy, output_generate = self._greedy_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 output_scores=True,
@@ -989,8 +989,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             model = model_class(config).to(torch_device).eval()
             output_greedy, output_generate = self._greedy_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 output_scores=True,
@@ -1019,8 +1019,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             # check `generate()` and `sample()` are equal
             output_sample, output_generate = self._sample_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 num_return_sequences=1,
@@ -1050,8 +1050,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             output_sample, output_generate = self._sample_generate(
                 model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
+                input_ids=input_ids.to(torch_device),
+                attention_mask=attention_mask.to(torch_device),
                 decoder_input_ids=decoder_input_ids,
                 max_length=max_length,
                 num_return_sequences=3,
@@ -1089,8 +1089,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
             model = model_class(config).eval().to(torch_device)
             if torch_device == "cuda":
                 model.half()
-            model.generate(**input_dict, max_new_tokens=10)
-            model.generate(**input_dict, do_sample=True, max_new_tokens=10)
+            # greedy
+            model.generate(input_dict["input_ids"], attention_mask=input_dict["attention_mask"], max_new_tokens=10)
+            # sampling
+            model.generate(
+                input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
+            )
 
 
 def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
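A hedged, standalone version of the fp16 smoke test above. The real test instantiates a tiny random model from the tester config; this sketch loads a pretrained checkpoint (`facebook/musicgen-small`, chosen here only for illustration) to keep it self-contained:

```python
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small").eval().to(torch_device)
if torch_device == "cuda":
    model.half()  # run the whole model in float16 on GPU, as the test does

processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
inputs = processor(text=["90s rock song with loud guitars"], return_tensors="pt").to(torch_device)

# greedy
model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_new_tokens=10)
# sampling
model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], do_sample=True, max_new_tokens=10)
```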
@@ -1230,8 +1234,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
         # fmt: off
         EXPECTED_VALUES = torch.tensor(
             [
-                0.0765, 0.0758, 0.0749, 0.0759, 0.0759, 0.0771, 0.0775, 0.0760,
-                0.0762, 0.0765, 0.0767, 0.0760, 0.0738, 0.0714, 0.0713, 0.0730,
+                -0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185,
+                0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053,
             ]
         )
         # fmt: on
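The integration tests compare a short slice of the generated audio against hard-coded reference values, and this commit re-records those values on CUDA hardware ("update with cuda values"). A sketch of how such a check is typically written; the slicing, tolerance, and helper name are illustrative, and only the tensor contents are copied from the updated reference above:

```python
import torch

# Reference slice copied from the updated (CUDA) expected values above.
EXPECTED_VALUES = torch.tensor(
    [
        -0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185,
        0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053,
    ]
)


def check_first_samples(output_values: torch.Tensor) -> None:
    """Compare the first generated audio samples against the reference slice."""
    # Assumes output_values is shaped (batch, channels, num_samples); cast to
    # float32 on CPU so fp16/CUDA outputs compare cleanly.
    first_samples = output_values[0, 0, : EXPECTED_VALUES.numel()].float().cpu()
    torch.testing.assert_close(first_samples, EXPECTED_VALUES, atol=1e-4, rtol=1e-4)
```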
@@ -1312,8 +1316,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
         # fmt: off
         EXPECTED_VALUES = torch.tensor(
             [
-                -0.0047, -0.0094, -0.0028, -0.0018, -0.0057, -0.0007, -0.0104, -0.0211,
-                -0.0097, -0.0150, -0.0066, -0.0004, -0.0201, -0.0325, -0.0326, -0.0098,
+                -0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229,
+                0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326,
             ]
         )
         # fmt: on