Unverified Commit 03f98f96 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[MusicGen] Fix integration tests (#25169)

* move to device

* update with cuda values

* fix fp16

* more rigorous
parent c90e14fb
......@@ -773,10 +773,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = torch.zeros((bsz, seq_len, self.d_model), device=input_ids.device)
for codebook in range(num_codebooks):
inputs_embeds += self.embed_tokens[codebook](input[:, codebook])
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......
......@@ -267,8 +267,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
......@@ -293,8 +293,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
output_scores=True,
output_hidden_states=True,
......@@ -324,8 +324,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=3,
logits_processor=logits_processor,
......@@ -356,8 +356,8 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
max_length=max_length,
num_return_sequences=1,
logits_processor=logits_processor,
......@@ -964,8 +964,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
......@@ -989,8 +989,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class(config).to(torch_device).eval()
output_greedy, output_generate = self._greedy_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
output_scores=True,
......@@ -1019,8 +1019,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# check `generate()` and `sample()` are equal
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
num_return_sequences=1,
......@@ -1050,8 +1050,8 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
output_sample, output_generate = self._sample_generate(
model=model,
input_ids=input_ids,
attention_mask=attention_mask,
input_ids=input_ids.to(torch_device),
attention_mask=attention_mask.to(torch_device),
decoder_input_ids=decoder_input_ids,
max_length=max_length,
num_return_sequences=3,
......@@ -1089,8 +1089,12 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
model = model_class(config).eval().to(torch_device)
if torch_device == "cuda":
model.half()
model.generate(**input_dict, max_new_tokens=10)
model.generate(**input_dict, do_sample=True, max_new_tokens=10)
# greedy
model.generate(input_dict["input_ids"], attention_mask=input_dict["attention_mask"], max_new_tokens=10)
# sampling
model.generate(
input_dict["input_ids"], attention_mask=input_dict["attention_mask"], do_sample=True, max_new_tokens=10
)
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
......@@ -1230,8 +1234,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
0.0765, 0.0758, 0.0749, 0.0759, 0.0759, 0.0771, 0.0775, 0.0760,
0.0762, 0.0765, 0.0767, 0.0760, 0.0738, 0.0714, 0.0713, 0.0730,
-0.0099, -0.0140, 0.0079, 0.0080, -0.0046, 0.0065, -0.0068, -0.0185,
0.0105, 0.0059, 0.0329, 0.0249, -0.0204, -0.0341, -0.0465, 0.0053,
]
)
# fmt: on
......@@ -1312,8 +1316,8 @@ class MusicgenIntegrationTests(unittest.TestCase):
# fmt: off
EXPECTED_VALUES = torch.tensor(
[
-0.0047, -0.0094, -0.0028, -0.0018, -0.0057, -0.0007, -0.0104, -0.0211,
-0.0097, -0.0150, -0.0066, -0.0004, -0.0201, -0.0325, -0.0326, -0.0098,
-0.0111, -0.0154, 0.0047, 0.0058, -0.0068, 0.0012, -0.0109, -0.0229,
0.0010, -0.0038, 0.0167, 0.0042, -0.0421, -0.0610, -0.0764, -0.0326,
]
)
# fmt: on
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment