Unverified commit 1551b7c5, authored by Yoach Lacombe and committed by GitHub
Browse files

Add the ability to return the generated audio lengths (`audios_length`) (#91)

parent 862f8418
......@@ -3511,6 +3511,7 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
output_ids,
audio_scales=audio_scales,
).audio_values.squeeze(1)
output_lengths = [audio.shape[0] for audio in output_values]
else:
output_values = []
for sample_id in range(batch_size):
......@@ -3522,13 +3523,14 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
output_values.append(sample.transpose(0, 2))
else:
output_values.append(torch.zeros((1, 1, 1)).to(self.device))
# Record each sample's unpadded length (first dimension) before the batch is padded to a common length below.
output_lengths = [audio.shape[0] for audio in output_values]
output_values = (
torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
.squeeze(-1)
.squeeze(-1)
)
if generation_config.return_dict_in_generate:
outputs["audios_length"] = output_lengths
outputs.sequences = output_values
return outputs
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment