Commit c734f3ec authored by Yoach Lacombe
Browse files

make style

parent a664e0ca
...@@ -95,9 +95,7 @@ with gr.Blocks(css=css) as block: ...@@ -95,9 +95,7 @@ with gr.Blocks(css=css) as block:
description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description") description = gr.Textbox(label="Description", lines=2, value="", elem_id="input_description")
run_button = gr.Button("Generate Audio", variant="primary") run_button = gr.Button("Generate Audio", variant="primary")
with gr.Column(): with gr.Column():
audio_out = gr.Audio( audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
label="Parler-TTS generation", type="numpy", elem_id="audio_out"
)
inputs = [input_text, description] inputs = [input_text, description]
outputs = [audio_out] outputs = [audio_out]
......
...@@ -22,7 +22,6 @@ if __name__ == "__main__": ...@@ -22,7 +22,6 @@ if __name__ == "__main__":
num_codebooks = encodec.num_codebooks num_codebooks = encodec.num_codebooks
print("num_codebooks", num_codebooks) print("num_codebooks", num_codebooks)
decoder_config = ParlerTTSDecoderConfig( decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size + 1, vocab_size=encodec_vocab_size + 1,
max_position_embeddings=2048, max_position_embeddings=2048,
...@@ -42,11 +41,9 @@ if __name__ == "__main__": ...@@ -42,11 +41,9 @@ if __name__ == "__main__":
num_codebooks=num_codebooks, num_codebooks=num_codebooks,
) )
decoder = ParlerTTSForCausalLM(decoder_config) decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained(os.path.join(args.save_directory, "decoder")) decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained( model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model, text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version, audio_encoder_pretrained_model_name_or_path=encodec_version,
......
...@@ -41,7 +41,6 @@ if __name__ == "__main__": ...@@ -41,7 +41,6 @@ if __name__ == "__main__":
decoder.save_pretrained(os.path.join(args.save_directory, "decoder")) decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained( model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model, text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version, audio_encoder_pretrained_model_name_or_path=encodec_version,
......
...@@ -22,7 +22,6 @@ if __name__ == "__main__": ...@@ -22,7 +22,6 @@ if __name__ == "__main__":
num_codebooks = encodec.num_codebooks num_codebooks = encodec.num_codebooks
print("num_codebooks", num_codebooks) print("num_codebooks", num_codebooks)
decoder_config = ParlerTTSDecoderConfig( decoder_config = ParlerTTSDecoderConfig(
vocab_size=encodec_vocab_size + 64, # + 64 instead of +1 to have a multiple of 64 vocab_size=encodec_vocab_size + 64, # + 64 instead of +1 to have a multiple of 64
max_position_embeddings=4096, # 30 s = 2580 max_position_embeddings=4096, # 30 s = 2580
...@@ -42,11 +41,9 @@ if __name__ == "__main__": ...@@ -42,11 +41,9 @@ if __name__ == "__main__":
num_codebooks=num_codebooks, num_codebooks=num_codebooks,
) )
decoder = ParlerTTSForCausalLM(decoder_config) decoder = ParlerTTSForCausalLM(decoder_config)
decoder.save_pretrained(os.path.join(args.save_directory, "decoder")) decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained( model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
text_encoder_pretrained_model_name_or_path=text_model, text_encoder_pretrained_model_name_or_path=text_model,
audio_encoder_pretrained_model_name_or_path=encodec_version, audio_encoder_pretrained_model_name_or_path=encodec_version,
...@@ -64,5 +61,4 @@ if __name__ == "__main__": ...@@ -64,5 +61,4 @@ if __name__ == "__main__":
model.generation_config.do_sample = True # True model.generation_config.do_sample = True # True
model.generation_config.guidance_scale = 1 # 3.0 model.generation_config.guidance_scale = 1 # 3.0
model.save_pretrained(os.path.join(args.save_directory, "stable-speech-untrained-300M/"))
model.save_pretrained(os.path.join(args.save_directory,"stable-speech-untrained-300M/"))
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment