from dataclasses import asdict
import time

import torch
import torchaudio
from IPython.display import Audio, display
from scipy.io import wavfile

from utils.audio import LogMelSpectrogram
from config import ModelConfig, VocosConfig, MelConfig
from models.model import StableTTS
from vocos_pytorch.models.model import Vocos
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from text import cleaned_text_to_sequence
from text import symbols
from datas.dataset import intersperse

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Grapheme-to-phoneme frontends, one per supported language.
g2p_mapping = {
    'chinese': chinese_to_cnm3,
    'japanese': japanese_to_ipa2,
    'english': english_to_ipa2,
}

@torch.inference_mode()
def inference(text: str, ref_audio: str, tts_model: StableTTS, mel_extractor: LogMelSpectrogram, vocoder: Vocos, phonemizer, sample_rate: int, step: int = 10) -> tuple[torch.Tensor, torch.Tensor]:
    # Phonemize the text, map phonemes to token ids, and intersperse blank tokens.
    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)
    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)

    # Load the reference audio and resample it to the model's sample rate if needed.
    waveform, sr = torchaudio.load(ref_audio)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    y = mel_extractor(waveform).to(device)

    # Synthesize a mel spectrogram conditioned on the reference, then vocode it to a waveform.
    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']
    audio = vocoder(mel)
    return audio.cpu(), mel.cpu()

def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path: str, vocoder_checkpoint_path: str):
    # Build the TTS model, mel extractor, and vocoder, then load the checkpoints.
    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))
    mel_extractor = LogMelSpectrogram(mel_config)
    vocoder = Vocos(vocoder_config, mel_config)
    tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))
    tts_model.to(device)
    tts_model.eval()
    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))
    vocoder.to(device)
    vocoder.eval()
    return tts_model, mel_extractor, vocoder

tts_model_config = ModelConfig()
mel_config = MelConfig()
vocoder_config = VocosConfig()

tts_checkpoint_path = './checkpoints/checkpoint-zh_0.pt'
vocoder_checkpoint_path = './checkpoints/vocoder.pt'

tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)

total_params = sum(p.numel() for p in tts_model.parameters()) / 1e6
print(f'Total params: {total_params:.2f} M')

language = 'chinese'  # supported languages: chinese, japanese, english
phonemizer = g2p_mapping.get(language)

text = '你好,世界!'  # "Hello, world!"
ref_audio = './audio.wav'

# start_time = time.time()
output, mel = inference(text, ref_audio, tts_model, mel_extractor, vocoder, phonemizer, mel_config.sample_rate, 15)
# print("infer time:", time.time() - start_time, "s")

display(Audio(ref_audio))
display(Audio(output, rate=mel_config.sample_rate))
wavfile.write('generate.wav', mel_config.sample_rate, output.numpy().T)
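
# Note: given a float32 array, scipy's wavfile.write produces an IEEE-float WAV,
# which some players do not decode. A minimal sketch of converting the output to
# 16-bit PCM before saving; it assumes the vocoder output stays roughly in [-1, 1]
# (the file name below is illustrative, not part of the original script).
import numpy as np

pcm = np.clip(output.numpy().T, -1.0, 1.0)  # guard against slight overshoot
pcm = (pcm * 32767).astype(np.int16)        # scale samples to the int16 range
wavfile.write('generate_int16.wav', mel_config.sample_rate, pcm)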