# inference.py
import time
from dataclasses import asdict

import torch
import torchaudio
from scipy.io import wavfile
from IPython.display import Audio, display

from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from utils.audio import LogMelSpectrogram
from text import cleaned_text_to_sequence, symbols
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from datas.dataset import intersperse

device = 'cuda' if torch.cuda.is_available() else 'cpu'

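# map each supported language to its grapheme-to-phoneme (g2p) front end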
g2p_mapping = {
    'chinese': chinese_to_cnm3,
    'japanese': japanese_to_ipa2,
    'english': english_to_ipa2,
}

@torch.inference_mode()
def inference(text: str, ref_audio: str, tts_model: StableTTS, mel_extractor: LogMelSpectrogram, vocoder: Vocos, phonemizer, sample_rate: int, step: int = 10) -> tuple[torch.Tensor, torch.Tensor]:
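    # convert the input text to phoneme ids and intersperse a blank token (0) between symbols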
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)
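    # load the reference audio and resample it to the model's sample rate if needed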
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
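    # extract the reference mel-spectrogram used to condition synthesis on the target voice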
    y = mel_extractor(waveform).to(device)
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)
    return audio.cpu(), mel.cpu()

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path: str, vocoder_checkpoint_path: str):
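    # instantiate the acoustic model, mel extractor and vocoder, then load checkpoints and move the models to the device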
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
    tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

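# default configurations for the acoustic model, mel front end and vocoder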
tts_model_config = ModelConfig()
mel_config = MelConfig()
vocoder_config = VocosConfig()

tts_checkpoint_path = './checkpoints/checkpoint-zh_0.pt'
vocoder_checkpoint_path = './checkpoints/vocoder.pt'

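# build the pipeline and report the acoustic model size in millions of parameters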
tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)
total_params = sum(p.numel() for p in tts_model.parameters()) / 1e6
print(f'Total params: {total_params:.2f} M')

language = 'chinese'  # currently supported: chinese, japanese, english

phonemizer = g2p_mapping.get(language)

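# input text and the reference audio that provides the target voice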
text = '你好,世界!'
ref_audio = './audio.wav'
# start_time = time.time()
output, mel = inference(text, ref_audio, tts_model, mel_extractor, vocoder, phonemizer, mel_config.sample_rate, 15)
# print("infer time:", time.time() - start_time, "s")
display(Audio(ref_audio))
display(Audio(output, rate=mel_config.sample_rate))
wavfile.write('generate.wav', mel_config.sample_rate, output.numpy().T)