import argparse

import torch
import torchaudio
from torchaudio.datasets import LJSPEECH
from torchaudio.models import wavernn
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS
from torchaudio.transforms import MelSpectrogram

from processing import NormalizeDB
from wavernn_inference_wrapper import WaveRNNInferenceWrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output-wav-path",
        default="./output.wav",
        type=str,
        metavar="PATH",
        help="The path to output the reconstructed wav file.",
    )
    parser.add_argument(
        "--jit",
        default=False,
        action="store_true",
        help="If used, the model and inference function are jitted.",
    )
    parser.add_argument(
        "--loss",
        default="crossentropy",
        choices=["crossentropy"],
        type=str,
        help="The type of loss the pretrained model is trained on.",
    )
    parser.add_argument(
        "--no-batch-inference",
        default=False,
        action="store_true",
        help="Don't use batch inference.",
    )
    parser.add_argument(
        "--no-mulaw",
        default=False,
        action="store_true",
        help="Don't use the mu-law decoder to decode the signal.",
    )
    parser.add_argument(
        "--checkpoint-name",
        default="wavernn_10k_epochs_8bits_ljspeech",
        choices=list(_MODEL_CONFIG_AND_URLS.keys()),
        help="Select the WaveRNN checkpoint.",
    )
    parser.add_argument(
        "--batch-timesteps",
        default=11000,
        type=int,
        help="The time steps for each batch. Only used when batch inference is used.",
    )
    parser.add_argument(
        "--batch-overlap",
        default=550,
        type=int,
        help="The overlapping time steps between batches. Only used when batch inference is used.",
    )
    args = parser.parse_args()
    return args


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the first utterance of LJSpeech as the reference waveform.
    waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]

    # Mel spectrogram parameters matching those the checkpoint was trained with.
    mel_kwargs = {
        "sample_rate": sample_rate,
        "n_fft": 2048,
        "f_min": 40.0,
        "n_mels": 80,
        "win_length": 1100,
        "hop_length": 275,
        "mel_scale": "slaney",
        "norm": "slaney",
        "power": 1,
    }
    transforms = torch.nn.Sequential(
        MelSpectrogram(**mel_kwargs),
        NormalizeDB(min_level_db=-100, normalization=True),
    )
    mel_specgram = transforms(waveform)

    # Load the pretrained WaveRNN checkpoint and wrap it for inference.
    wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
    wavernn_model = WaveRNNInferenceWrapper(wavernn_model)

    if args.jit:
        wavernn_model = torch.jit.script(wavernn_model)

    # Reconstruct the waveform from the mel spectrogram.
    with torch.no_grad():
        output = wavernn_model.infer(
            mel_specgram.to(device),
            loss_name=args.loss,
            mulaw=(not args.no_mulaw),
            batched=(not args.no_batch_inference),
            timesteps=args.batch_timesteps,
            overlap=args.batch_overlap,
        )

    torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)


if __name__ == "__main__":
    args = parse_args()
    main(args)
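
# Example invocation, using the flags defined in parse_args() above. The
# filename "inference_wavernn.py" is illustrative; substitute whatever name
# this script is saved under. With defaults, this downloads LJSpeech,
# computes the mel spectrogram of its first utterance, and writes the
# reconstructed audio to ./output.wav:
#
#   python inference_wavernn.py \
#       --checkpoint-name wavernn_10k_epochs_8bits_ljspeech \
#       --output-wav-path ./output.wav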