inference.py 1.26 KB
Newer Older
changhl's avatar
changhl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torchaudio
from speechbrain.inference.TTS import Tacotron2
from speechbrain.inference.vocoders import HIFIGAN
import os
import argparse


def parse_opt(known=False):
    """Parse the command-line options of the inference script.

    Args:
        known: When true, parse with ``parse_known_args`` and return only the
            recognised options, so unrecognised arguments do not abort the
            program. When false, parse strictly with ``parse_args``.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``model_path``,
        ``vocoder_path``, ``text`` and ``result_path``.
    """
    # (flags, default value, help text) for every supported option; all of
    # them are plain strings.
    option_table = (
        (('-m', '--model-path'), "", "the tacotron2 model path"),
        (('-v', '--vocoder-path'), "", "the vocoder model path"),
        (('-t', '--text'), "Autumn, the season of change.", "input text"),
        (('-res', '--result_path'), "./res", "the path to save wav file"),
    )

    cli = argparse.ArgumentParser()
    for flags, default_value, help_text in option_table:
        cli.add_argument(*flags, type=str, default=default_value, help=help_text)

    if known:
        namespace, _unrecognised = cli.parse_known_args()
        return namespace
    return cli.parse_args()


def main(opt):
    """Synthesize speech for ``opt.text`` and save it as ``example.wav``.

    The text is encoded to a mel spectrogram with a Tacotron2 model, the
    spectrogram is decoded to a waveform with a HiFi-GAN vocoder, and the
    waveform is written to ``<opt.result_path>/example.wav``.

    Args:
        opt: Parsed options providing ``model_path`` (Tacotron2 source),
            ``vocoder_path`` (HiFi-GAN source), ``text`` (sentence to speak)
            and ``result_path`` (output directory, created if missing).
    """
    # Imported locally so this function does not depend on a module-level
    # import that the file does not have.
    import torch

    # Use the GPU when one is usable, otherwise fall back to the CPU instead
    # of failing while loading the models.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tacotron2 = Tacotron2.from_hparams(source=opt.model_path, run_opts={"device": device})
    hifi_gan = HIFIGAN.from_hparams(source=opt.vocoder_path, run_opts={"device": device})

    # Running the TTS (text-to-spectrogram); only the mel output is used.
    mel_output, _mel_length, _alignment = tacotron2.encode_text(opt.text)

    # Running Vocoder (spectrogram-to-waveform)
    waveforms = hifi_gan.decode_batch(mel_output)

    # Make sure the output directory exists; an empty path means the current
    # working directory, which needs no creation.
    if opt.result_path:
        os.makedirs(opt.result_path, exist_ok=True)

    # Save the waveform.
    # NOTE(review): 22050 is assumed to be the vocoder's output sample rate —
    # confirm against the model's hyperparameters.
    output_file = os.path.join(opt.result_path, 'example.wav')
    torchaudio.save(output_file, waveforms.squeeze(1).cpu(), 22050)

if __name__ == "__main__":
    # Script entry point: read the command-line options, then run synthesis.
    cli_options = parse_opt()
    main(opt=cli_options)