# *****************************************************************************
# Copyright (c) 2019 fatchord (https://github.com/fatchord)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************


from torchaudio.models.wavernn import WaveRNN
import torch
import torchaudio
from torch import Tensor

29
from processing import normalized_waveform_to_bits
30

31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class WaveRNNInferenceWrapper(torch.nn.Module):
    r"""Wrap a ``WaveRNN`` model to generate a waveform from a spectrogram.

    With ``batched=True`` the spectrogram is split into overlapping folds that are run
    through the model as one batch, and the per-fold outputs are crossfaded back together.
    """

    def __init__(self, wavernn: WaveRNN):
        super().__init__()
        self.wavernn_model = wavernn

    def _fold_with_overlap(self, x: Tensor, timesteps: int, overlap: int) -> Tensor:
        r'''Fold the tensor with overlap for quick batched inference.
        Overlap will be used for crossfading in _xfade_and_unfold().

        x = [[h1, h2, ... hn]]
        Where each h is a vector of conditioning channels
        Eg: timesteps=2, overlap=1 with x.size(2)=10
        folded = [[h1, h2, h3, h4],
                  [h4, h5, h6, h7],
                  [h7, h8, h9, h10]]

        Args:
            x (Tensor): Conditioning tensor of size (1, channels, total_len). Only the
                first batch entry is folded.
            timesteps (int): Timesteps for each index of batch.
            overlap (int): Timesteps for both xfade and rnn warmup.

        Return:
            folded (Tensor): folded tensor of size (n_folds, channels, timesteps + 2 * overlap),
                with the same dtype and device as ``x``. ``x`` is zero-padded at the end so
                that every time step lands in some fold; at least one fold is returned.
        '''
        _, channels, total_len = x.size()
        step = timesteps + overlap
        window = timesteps + 2 * overlap

        # Ceil-divide so trailing time steps get their own zero-padded fold. Keep at least
        # one fold so inputs no longer than ``overlap`` do not yield an empty batch.
        n_folds = max(-(-(total_len - overlap) // step), 1)
        padded_len = n_folds * step + overlap
        if padded_len > total_len:
            x = self._pad_tensor(x, padded_len - total_len, side='after')

        folded = torch.zeros((n_folds, channels, window), dtype=x.dtype, device=x.device)

        # Consecutive folds start ``step`` apart, so neighbours share ``overlap`` time steps.
        for i in range(n_folds):
            start = i * step
            folded[i] = x[0, :, start:start + window]

        return folded

    def _xfade_and_unfold(self, y: Tensor, overlap: int) -> Tensor:
        r'''Crossfade overlapping folds and unfold them into one sequence per channel.

        y = [[seq1],
             [seq2],
             [seq3]]
        Apply a gain envelope at both ends of the sequences
        y = [[seq1_in, seq1_timesteps, seq1_out],
             [seq2_in, seq2_timesteps, seq2_out],
             [seq3_in, seq3_timesteps, seq3_out]]
        Stagger and add up the groups of samples:
            [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]

        Args:
            y (Tensor): Batched sequences of audio samples of size
                (num_folds, channels, timesteps + 2 * overlap). ``y`` is not modified.
            overlap (int): Timesteps for both xfade and rnn warmup. ``0`` disables the
                crossfade and simply concatenates the folds.

        Returns:
            unfolded waveform (Tensor): waveform of size (channels, total_len), where
                total_len = num_folds * (timesteps + overlap) + overlap.
        '''
        num_folds, channels, length = y.shape
        timesteps = length - 2 * overlap
        step = timesteps + overlap
        total_len = num_folds * step + overlap

        # Need some silence for the rnn warmup: in the first half of each overlap the
        # incoming fold is muted while the outgoing fold keeps full gain; the actual
        # crossfade happens over the remaining ``fade_len`` samples.
        silence_len = overlap // 2
        fade_len = overlap - silence_len
        silence = torch.zeros([silence_len], dtype=y.dtype, device=y.device)
        linear = torch.ones([silence_len], dtype=y.dtype, device=y.device)

        # Equal power crossfade: fade_in ** 2 + fade_out ** 2 == 1 at every sample.
        t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device)
        fade_in = torch.cat([silence, torch.sqrt(0.5 * (1 + t))])
        fade_out = torch.cat([linear, torch.sqrt(0.5 * (1 - t))])

        # Build one gain envelope and apply it out of place so the caller's tensor is left
        # untouched. Slicing from ``length - overlap`` (rather than ``-overlap``) keeps the
        # tail slice empty when overlap == 0, whereas ``-0:`` would select every sample.
        gain = torch.ones([length], dtype=y.dtype, device=y.device)
        gain[:overlap] *= fade_in
        gain[length - overlap:] *= fade_out
        faded = y * gain

        # Stagger the folds by ``step`` and sum, so each fold's faded tail overlaps the
        # next fold's faded head.
        unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)
        for i in range(num_folds):
            start = i * step
            unfolded[:, start:start + length] += faded[i]

        return unfolded

    def _pad_tensor(self, x: Tensor, pad: int, side: str = 'both') -> Tensor:
        r"""Zero-pad the given tensor along its last (time) dimension.

        Args:
            x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
            pad (int): The amount of padding applied to the input.
            side (str): Where to pad: ``'before'``, ``'after'`` or ``'both'`` (``pad``
                samples on each side). (Default: ``'both'``)

        Return:
            padded (Tensor): The padded tensor of size (n_batch, n_mels, time + 2 * pad)
                for ``side='both'``, otherwise (n_batch, n_mels, time + pad); same dtype
                and device as ``x``.

        Raises:
            ValueError: If ``side`` is not one of ``'both'``, ``'before'`` or ``'after'``.
        """
        b, c, t = x.size()
        # Validate ``side`` before allocating anything.
        if side == 'both':
            total = t + 2 * pad
        elif side == 'before' or side == 'after':
            total = t + pad
        else:
            raise ValueError(f"Unexpected side: '{side}'. "
                             f"Valid choices are 'both', 'before' and 'after'.")

        padded = torch.zeros((b, c, total), dtype=x.dtype, device=x.device)
        if side == 'after':
            padded[:, :, :t] = x
        else:
            # 'before' and 'both' shift the data right by ``pad``.
            padded[:, :, pad:pad + t] = x
        return padded

    def forward(self,
                specgram: Tensor,
                mulaw: bool = True,
                batched: bool = True,
                timesteps: int = 100,
                overlap: int = 5) -> Tensor:
        r"""Inference function for WaveRNN.

        Based on the implementation from
        https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py.

        Currently only supports multinomial sampling.

        Args:
            specgram (Tensor): spectrogram of size (n_mels, n_time)
            mulaw (bool): Whether to perform mulaw decoding (Default: ``True``).
            batched (bool): Whether to perform batch prediction. Using batch prediction
                will significantly increase the inference speed (Default: ``True``).
            timesteps (int): The time steps for each batch. Only used when `batched`
                is set to True (Default: ``100``).
            overlap (int): The overlapping time steps between batches. Only used when `batched`
                is set to True (Default: ``5``).

        Returns:
            waveform (Tensor): Reconstructed waveform, always moved to CPU. In the batched
                path this is the (channels, n_samples) result of the crossfade; otherwise
                the first entry of the model output. Presumably (1, n_samples), with
                1 representing a single channel (TODO confirm against ``WaveRNN.infer``).
        """
        # Pad both ends by (kernel_size - 1) // 2 frames, presumably so the model's kernel
        # sees context at the edges of the spectrogram.
        pad = (self.wavernn_model.kernel_size - 1) // 2

        specgram = specgram.unsqueeze(0)  # (1, n_mels, n_time)
        specgram = self._pad_tensor(specgram, pad=pad, side='both')
        if batched:
            # NOTE(review): folding counts ``timesteps``/``overlap`` in spectrogram frames,
            # but the same ``overlap`` is later used as a sample count when crossfading the
            # model output; confirm this matches the per-fold output length of ``infer``.
            specgram = self._fold_with_overlap(specgram, timesteps, overlap)

        output = self.wavernn_model.infer(specgram).cpu()

        if mulaw:
            # n_classes is expected to be a power of two (e.g. 256 -> 8 bits).
            n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
            output = normalized_waveform_to_bits(output, n_bits)
            output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)

        if batched:
            output = self._xfade_and_unfold(output, overlap)
        else:
            output = output[0]

        return output