speech_tts.py
import os

# Use a mirror for Hugging Face downloads; HF_ENDPOINT must be set before
# transformers / huggingface_hub are imported, since they read it at import time.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import argparse

import numpy as np
import soundfile as sf
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('-hip', '--hip-device', type=int, default=0, help="HIP/CUDA device index to run on")
    parser.add_argument('-m', '--model-path', type=str, default="", help="path to the SpeechT5 TTS model")
    parser.add_argument('-v', '--vocoder-path', type=str, default="", help="path to the HiFi-GAN vocoder model")
    parser.add_argument('-t', '--text', type=str, default="Autumn, the season of change.", help="input text to synthesize")
    parser.add_argument('-s', '--speaker', type=str, default="", help="path to the speaker embedding (.npy file)")
    parser.add_argument('-res', '--result-path', type=str, default="../res", help="directory in which to save the output wav file")
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt
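
# Example invocation (the paths below are illustrative placeholders, not files
# shipped with this script):
#
#   python speech_tts.py -hip 0 \
#       -m /path/to/speecht5_tts \
#       -v /path/to/speecht5_hifigan \
#       -s /path/to/speaker_xvector.npy \
#       -t "Autumn, the season of change." \
#       -res ../res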

def main(opt):
    # AMD HIP (ROCm) devices are exposed through PyTorch's "cuda" device type.
    device = torch.device(f"cuda:{int(opt.hip_device)}")
    print(f"Using device: {device}")

    # Load the SpeechT5 TTS processor and model.
    processor = SpeechT5Processor.from_pretrained(opt.model_path)
    model = SpeechT5ForTextToSpeech.from_pretrained(opt.model_path).to(device)

    # Encoder input: tokenize the input text.
    inputs = processor(text=opt.text, return_tensors="pt").to(device)

    # Decoder input: the speaker embedding (x-vector) that conditions the voice.
    speaker_embeddings = np.load(opt.speaker).astype(np.float32)
    speaker_embeddings = torch.from_numpy(speaker_embeddings).unsqueeze(0).to(device)
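
    # Alternative (a sketch, not used by this script): the stock Hugging Face
    # SpeechT5 example takes a speaker x-vector from the CMU Arctic x-vector
    # dataset instead of a local .npy file; index 7306 is simply the speaker
    # used in that example.
    #
    #   from datasets import load_dataset
    #   embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    #   speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)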

    # Output: the HiFi-GAN vocoder converts the predicted mel spectrogram into a waveform.
    vocoder = SpeechT5HifiGan.from_pretrained(opt.vocoder_path).to(device)
    speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
    os.makedirs(opt.result_path, exist_ok=True)
    sf.write(os.path.join(opt.result_path, "tts.wav"), speech.cpu().numpy(), samplerate=16000)
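
    # Note (an optional variant, per the transformers docs): generate_speech can
    # also be called without a vocoder, in which case it returns the mel
    # spectrogram, and the vocoder can then be applied as a separate step:
    #
    #   spectrogram = model.generate_speech(inputs["input_ids"], speaker_embeddings)
    #   with torch.no_grad():
    #       speech = vocoder(spectrogram)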

if __name__ == "__main__":
    main(opt=parse_opt())