inference.py
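"""WaveRNN vocoder inference example.

Reconstructs an LJSpeech utterance from its mel spectrogram with a pretrained
WaveRNN checkpoint. Example invocation (a sketch; the flags are the ones
defined in parse_args below):

    python inference.py --checkpoint-name wavernn_10k_epochs_8bits_ljspeech \
        --output-wav-path ./output.wav
"""
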
import argparse

import torch
import torchaudio
from processing import NormalizeDB
from torchaudio.datasets import LJSPEECH
from torchaudio.models import wavernn
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS
from torchaudio.transforms import MelSpectrogram
from wavernn_inference_wrapper import WaveRNNInferenceWrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output-wav-path",
        default="./output.wav",
        type=str,
        metavar="PATH",
        help="The path to output the reconstructed wav file.",
    )
    parser.add_argument(
        "--jit", default=False, action="store_true", help="If used, the model and inference function are jitted."
    )
    parser.add_argument("--no-batch-inference", default=False, action="store_true", help="Don't use batch inference.")
    parser.add_argument(
        "--no-mulaw", default=False, action="store_true", help="Don't use the mu-law decoder to decode the signal."
    )
    parser.add_argument(
        "--checkpoint-name",
        default="wavernn_10k_epochs_8bits_ljspeech",
        choices=list(_MODEL_CONFIG_AND_URLS.keys()),
        help="Select the WaveRNN checkpoint.",
    )
    parser.add_argument(
        "--batch-timesteps",
        default=100,
        type=int,
        help="The time steps for each batch. Only used when batch inference is used.",
    )
    parser.add_argument(
        "--batch-overlap",
        default=5,
        type=int,
        help="The overlapping time steps between batches. Only used when batch inference is used.",
    )
    args = parser.parse_args()
    return args


def main(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
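    # Use the first LJSpeech utterance as the reference waveform; the dataset
    # is downloaded into the current directory on first run.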
    waveform, sample_rate, _, _ = LJSPEECH("./", download=True)[0]

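    # Mel-spectrogram settings; these should match the configuration the
    # selected pretrained checkpoint was trained with.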
    mel_kwargs = {
        "sample_rate": sample_rate,
        "n_fft": 2048,
        "f_min": 40.0,
        "n_mels": 80,
        "win_length": 1100,
        "hop_length": 275,
        "mel_scale": "slaney",
        "norm": "slaney",
        "power": 1,
    }
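    # Chain the mel front end with the example's dB normalization transform.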
    transforms = torch.nn.Sequential(
        MelSpectrogram(**mel_kwargs),
        NormalizeDB(min_level_db=-100, normalization=True),
    )
    mel_specgram = transforms(waveform)

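    # Load the pretrained WaveRNN and wrap it with the inference helper, which
    # adds the (optionally batched) autoregressive generation loop.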
    wavernn_model = wavernn(args.checkpoint_name).eval().to(device)
    wavernn_inference_model = WaveRNNInferenceWrapper(wavernn_model)

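    # Optionally compile the inference wrapper with TorchScript.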
    if args.jit:
        wavernn_inference_model = torch.jit.script(wavernn_inference_model)

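    # Generate audio from the mel spectrogram. With batched inference, the
    # spectrogram is split into overlapping chunks that are generated in
    # parallel and then stitched back together across the overlaps.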
    with torch.no_grad():
        output = wavernn_inference_model(
            mel_specgram.to(device),
            mulaw=(not args.no_mulaw),
            batched=(not args.no_batch_inference),
            timesteps=args.batch_timesteps,
            overlap=args.batch_overlap,
        )

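    # Save the generated waveform at the original LJSpeech sample rate.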
    torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)


if __name__ == "__main__":
    args = parse_args()
    main(args)