{ "cells": [
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "from dataclasses import asdict\n",
   "import torch\n",
   "import torchaudio\n",
   "from IPython.display import Audio, display\n",
   "\n",
   "from utils.audio import LogMelSpectrogram\n",
   "from config import ModelConfig, VocosConfig, MelConfig\n",
   "from models.model import StableTTS\n",
   "from vocos_pytorch.models.model import Vocos\n",
   "from text.mandarin import chinese_to_cnm3\n",
   "from text.english import english_to_ipa2\n",
   "from text.japanese import japanese_to_ipa2\n",
   "from text import cleaned_text_to_sequence\n",
   "from text import symbols\n",
   "from datas.dataset import intersperse\n",
   "\n",
   "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
   "\n",
   "# grapheme-to-phoneme converters for each supported language\n",
   "g2p_mapping = {\n",
   "    'chinese': chinese_to_cnm3,\n",
   "    'japanese': japanese_to_ipa2,\n",
   "    'english': english_to_ipa2,\n",
   "}\n",
   "\n",
   "@torch.inference_mode()\n",
   "def inference(text: str, ref_audio: str, tts_model: StableTTS, mel_extractor: LogMelSpectrogram, vocoder: Vocos, phonemizer, sample_rate: int, step: int = 10) -> tuple[torch.Tensor, torch.Tensor]:\n",
   "    # convert the text to a phoneme id sequence, interspersed with blank tokens\n",
   "    x = torch.tensor(intersperse(cleaned_text_to_sequence(phonemizer(text)), item=0), dtype=torch.long, device=device).unsqueeze(0)\n",
   "    x_len = torch.tensor([x.size(-1)], dtype=torch.long, device=device)\n",
   "    # load the reference audio and resample it to the model's sample rate if necessary\n",
   "    waveform, sr = torchaudio.load(ref_audio)\n",
   "    if sr != sample_rate:\n",
   "        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)\n",
   "    y = mel_extractor(waveform).to(device)\n",
   "    # synthesise a mel-spectrogram conditioned on the reference mel, then vocode it to a waveform\n",
   "    mel = tts_model.synthesise(x, x_len, step, y=y, temperature=0.667, length_scale=1)['decoder_outputs']\n",
   "    audio = vocoder(mel)\n",
   "    return audio.cpu(), mel.cpu()\n",
   "\n",
   "def get_pipeline(n_vocab: int, tts_model_config: ModelConfig, mel_config: MelConfig, vocoder_config: VocosConfig, tts_checkpoint_path: str, vocoder_checkpoint_path: str):\n",
   "    # build the TTS model, mel extractor and vocoder, then load the checkpoints\n",
   "    tts_model = StableTTS(n_vocab, mel_config.n_mels, **asdict(tts_model_config))\n",
   "    mel_extractor = LogMelSpectrogram(mel_config)\n",
   "    vocoder = Vocos(vocoder_config, mel_config)\n",
   "    tts_model.load_state_dict(torch.load(tts_checkpoint_path, map_location='cpu'))\n",
   "    tts_model.to(device)\n",
   "    tts_model.eval()\n",
   "    vocoder.load_state_dict(torch.load(vocoder_checkpoint_path, map_location='cpu'))\n",
   "    vocoder.to(device)\n",
   "    vocoder.eval()\n",
   "    return tts_model, mel_extractor, vocoder\n",
   "\n",
   "tts_model_config = ModelConfig()\n",
   "mel_config = MelConfig()\n",
   "vocoder_config = VocosConfig()\n",
   "\n",
   "tts_checkpoint_path = './checkpoints/checkpoint-zh_0.pt'\n",
   "vocoder_checkpoint_path = './checkpoints/vocoder.pt'\n",
   "\n",
   "tts_model, mel_extractor, vocoder = get_pipeline(len(symbols), tts_model_config, mel_config, vocoder_config, tts_checkpoint_path, vocoder_checkpoint_path)\n",
   "total_params = sum(p.numel() for p in tts_model.parameters()) / 1e6\n",
   "print(f'Total params: {total_params:.2f} M')"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "language = 'chinese'  # currently supported languages: chinese, japanese and english\n",
   "\n",
   "phonemizer = g2p_mapping.get(language)\n",
   "\n",
   "text = '你好,世界!'  # 'Hello, world!'\n",
   "ref_audio = './audio.wav'\n",
   "output, mel = inference(text, ref_audio, tts_model, mel_extractor, vocoder, phonemizer, mel_config.sample_rate, 15)\n",
   "display(Audio(ref_audio))\n",
   "display(Audio(output, rate=mel_config.sample_rate))"
  ]
 }
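,
 {
  "cell_type": "markdown",
  "metadata": {},
  "source": [
   "Optional follow-up: a minimal sketch for saving the synthesized waveform and inspecting the predicted mel-spectrogram. It assumes `output` is a mono waveform tensor and `mel` a `(1, n_mels, frames)` tensor as returned by `inference` above; the output path `./output.wav` and the matplotlib plot are illustrative choices, not part of the original pipeline."
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "# optional sketch: persist the waveform and visualise the predicted mel-spectrogram\n",
   "import matplotlib.pyplot as plt\n",
   "import torchaudio\n",
   "\n",
   "# torchaudio.save expects a 2-D (channels, time) CPU tensor; reshape defensively\n",
   "# in case the vocoder output carries an extra batch dimension\n",
   "output_path = './output.wav'  # hypothetical output location\n",
   "torchaudio.save(output_path, output.reshape(1, -1).cpu(), mel_config.sample_rate)\n",
   "print(f'Saved synthesized audio to {output_path}')\n",
   "\n",
   "# quick sanity-check plot of the predicted mel-spectrogram\n",
   "plt.figure(figsize=(10, 4))\n",
   "plt.imshow(mel.squeeze(0).numpy(), aspect='auto', origin='lower', interpolation='none')\n",
   "plt.xlabel('Frames')\n",
   "plt.ylabel('Mel bins')\n",
   "plt.title('Predicted mel-spectrogram')\n",
   "plt.colorbar()\n",
   "plt.show()"
  ]
 }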
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 2 }