import math
from typing import List, Tuple, Optional

import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F

__all__ = [
    "ResBlock",
    "MelResNet",
    "Stretch2d",
    "UpsampleNetwork",
    "WaveRNN",
]


class ResBlock(nn.Module):
    r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].

    Args:
        n_freq: the number of bins in a spectrogram. (Default: ``128``)

    Example:
        >>> resblock = ResBlock()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = resblock(input)  # shape: (10, 128, 512)
    """

    def __init__(self, n_freq: int = 128) -> None:
        super().__init__()

        self.resblock_model = nn.Sequential(
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq)
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the ResBlock layer.

        Args:
            specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_freq, n_time)
        """

        return self.resblock_model(specgram) + specgram


class MelResNet(nn.Module):
    r"""MelResNet layer applies a stack of ResBlocks to a spectrogram.

    Args:
        n_res_block: the number of ResBlocks in the stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of each ResBlock. (Default: ``128``)
        n_output: the number of output dimensions of the MelResNet. (Default: ``128``)
        kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)

    Example:
        >>> melresnet = MelResNet()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = melresnet(input)  # shape: (10, 128, 508)
    """

    def __init__(self,
                 n_res_block: int = 10,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128,
                 kernel_size: int = 5) -> None:
        super().__init__()

        ResBlocks = [ResBlock(n_hidden) for _ in range(n_res_block)]

        self.melresnet_model = nn.Sequential(
            nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
            *ResBlocks,
            nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1)
        )

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the MelResNet layer.

        Args:
            specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
        """

        return self.melresnet_model(specgram)


class Stretch2d(nn.Module):
    r"""Upscale the frequency and time dimensions of a spectrogram.

    Args:
        time_scale: the scale factor in the time dimension
        freq_scale: the scale factor in the frequency dimension

    Example:
        >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)

        >>> input = torch.rand(10, 100, 512)  # a random spectrogram
        >>> output = stretch2d(input)  # shape: (10, 500, 5120)
    """

    def __init__(self,
                 time_scale: int,
                 freq_scale: int) -> None:
        super().__init__()

        self.freq_scale = freq_scale
        self.time_scale = time_scale

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the Stretch2d layer.

        Args:
            specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).

        Return:
            Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
        """

        return specgram.repeat_interleave(self.freq_scale, -2).repeat_interleave(self.time_scale, -1)


class UpsampleNetwork(nn.Module):
    r"""Upscale the dimensions of a spectrogram.

    Args:
        upsample_scales: the list of upsample scales.
        n_res_block: the number of ResBlocks in the stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of each ResBlock. (Default: ``128``)
        n_output: the number of output dimensions of the MelResNet. (Default: ``128``)
        kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)

    Example:
        >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
        >>> input = torch.rand(10, 128, 10)  # a random spectrogram
        >>> output = upsamplenetwork(input)  # shapes: (10, 128, 1536), (10, 128, 1536)
    """

    def __init__(self,
                 upsample_scales: List[int],
                 n_res_block: int = 10,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128,
                 kernel_size: int = 5) -> None:
        super().__init__()

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        self.total_scale: int = total_scale

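        # trimming this many samples from each end of the upsampled output
        # aligns it with the MelResNet path, whose convolution shortens the
        # input by (kernel_size - 1) frames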
        self.indent = (kernel_size - 1) // 2 * total_scale
        self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.resnet_stretch = Stretch2d(total_scale, 1)

        up_layers = []
        for scale in upsample_scales:
            stretch = Stretch2d(scale, 1)
            conv = nn.Conv2d(in_channels=1,
                             out_channels=1,
                             kernel_size=(1, scale * 2 + 1),
                             padding=(0, scale),
                             bias=False)
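            # initialize each kernel as a box (moving-average) filter so the
            # stretched spectrogram is smoothed rather than merely repeated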
            conv.weight.data.fill_(1. / (scale * 2 + 1))
            up_layers.append(stretch)
            up_layers.append(conv)
        self.upsample_layers = nn.Sequential(*up_layers)

    def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the UpsampleNetwork layer.

        Args:
            specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
                          (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        where total_scale is the product of all elements in upsample_scales.
        """

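        # auxiliary path: compute MelResNet features and stretch them along
        # time so they line up with the upsampled spectrogram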
        resnet_output = self.resnet(specgram).unsqueeze(1)
        resnet_output = self.resnet_stretch(resnet_output)
        resnet_output = resnet_output.squeeze(1)

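        # main path: stretch the spectrogram, smooth it with the box filters,
        # then trim the edges that the MelResNet path does not cover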
        specgram = specgram.unsqueeze(1)
        upsampling_output = self.upsample_layers(specgram)
        upsampling_output = upsampling_output.squeeze(1)[:, :, self.indent:-self.indent]

        return upsampling_output, resnet_output


class WaveRNN(nn.Module):
    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.

    The original implementation was introduced in *Efficient Neural Audio Synthesis*
    [:footcite:`kalchbrenner2018efficient`]. The input channels of waveform and spectrogram have to be 1.
    The product of ``upsample_scales`` must equal ``hop_length``.

    Args:
        upsample_scales: the list of upsample scales.
        n_classes: the number of output classes.
        hop_length: the number of samples between the starts of consecutive frames.
        n_res_block: the number of ResBlocks in the stack. (Default: ``10``)
        n_rnn: the dimension of the RNN layer. (Default: ``512``)
        n_fc: the dimension of the fully connected layer. (Default: ``512``)
        kernel_size: the kernel size of the first Conv1d layer. (Default: ``5``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of each ResBlock. (Default: ``128``)
        n_output: the number of output dimensions of the MelResNet. (Default: ``128``)

    Example:
        >>> wavernn = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
        >>> waveform, sample_rate = torchaudio.load(file)
        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
        >>> output = wavernn(waveform, specgram)
        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
    """

    def __init__(self,
                 upsample_scales: List[int],
                 n_classes: int,
                 hop_length: int,
                 n_res_block: int = 10,
                 n_rnn: int = 512,
                 n_fc: int = 512,
                 kernel_size: int = 5,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128) -> None:
        super().__init__()

        self.kernel_size = kernel_size
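        # padding applied to the spectrogram in infer() so that no output
        # samples are lost to the MelResNet convolution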
        self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
        self.n_rnn = n_rnn
        self.n_aux = n_output // 4
        self.hop_length = hop_length
        self.n_classes = n_classes
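        # bits per sample implied by the number of classes (e.g. 512 -> 9 bits)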
        self.n_bits: int = int(math.log2(self.n_classes))

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        if total_scale != self.hop_length:
            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

        self.upsample = UpsampleNetwork(upsample_scales,
                                        n_res_block,
                                        n_freq,
                                        n_hidden,
                                        n_output,
                                        kernel_size)
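
        # the first FC layer sees the current sample (1), one spectrogram
        # frame (n_freq) and one auxiliary split (n_aux) concatenated together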
        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
        self.fc3 = nn.Linear(n_fc, self.n_classes)

    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
        r"""Pass the input through the WaveRNN model.

        Args:
            waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
            specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
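
        Example:
            A minimal sketch with random inputs, assuming the default
            ``kernel_size=5`` and the scales from the class example.

            >>> wavernn = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
            >>> waveform = torch.rand(2, 1, (10 - 5 + 1) * 200)
            >>> specgram = torch.rand(2, 1, 128, 10)
            >>> output = wavernn(waveform, specgram)  # shape: (2, 1, 1200, 512)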
        """

        assert waveform.size(1) == 1, 'Require the input channel of waveform to be 1'
        assert specgram.size(1) == 1, 'Require the input channel of specgram to be 1'
        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

        batch_size = waveform.size(0)
        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        # output of upsample:
        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        specgram, aux = self.upsample(specgram)
        specgram = specgram.transpose(1, 2)
        aux = aux.transpose(1, 2)

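        # split the auxiliary features into four equal parts along the feature
        # axis; one part conditions each stage of the network below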
        aux_idx = [self.n_aux * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
        x = self.fc(x)
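        # residual connections around both GRU layers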
        res = x
        x, _ = self.rnn1(x, h1)

        x = x + res
        res = x
        x = torch.cat([x, a2], dim=-1)
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)

        x = torch.cat([x, a4], dim=-1)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)

        # bring back channel dimension
        return x.unsqueeze(1)

    @torch.jit.export
    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Inference method of WaveRNN.

        This function currently only supports multinomial sampling, which assumes the
        network is trained on cross entropy loss.

        Args:
            specgram (Tensor):
                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When the ``specgram`` contains spectrograms with different duration,
                by providing ``lengths`` argument, the model will compute
                the corresponding valid output lengths.
                If ``None``, it is assumed that all the spectrograms in
                ``specgram`` have valid length. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
            Tensor
                The inferred waveform of size `(n_batch, 1, n_time)`.
                1 stands for a single channel.
            Tensor or None
                If the ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is returned. It indicates the valid length in the time axis of the
                output Tensor.
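
        Example:
            A minimal sketch with random inputs; the sizes mirror the class
            docstring example.

            >>> wavernn = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
            >>> specgram = torch.rand(2, 128, 10)  # a batch of two 10-frame spectrograms
            >>> lengths = torch.tensor([10, 8])  # valid frames per spectrogram
            >>> waveform, out_lengths = wavernn.infer(specgram, lengths)
            >>> # waveform shape: (2, 1, 2000); out_lengths: tensor([2000, 1600])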
        """

        device = specgram.device
        dtype = specgram.dtype

        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
        specgram, aux = self.upsample(specgram)
        if lengths is not None:
            lengths = lengths * self.upsample.total_scale

        output: List[Tensor] = []
        b_size, _, seq_len = specgram.size()

        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        x = torch.zeros((b_size, 1), device=device, dtype=dtype)

        aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)]

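        # autoregressive loop: generate one sample per upsampled time step and
        # feed it back in as the input for the next step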
        for i in range(seq_len):

            m_t = specgram[:, :, i]

            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

            x = torch.cat([x, m_t, a1_t], dim=1)
            x = self.fc(x)
            _, h1 = self.rnn1(x.unsqueeze(1), h1)

            x = x + h1[0]
            inp = torch.cat([x, a2_t], dim=1)
            _, h2 = self.rnn2(inp.unsqueeze(1), h2)

            x = x + h2[0]
            x = torch.cat([x, a3_t], dim=1)
            x = F.relu(self.fc1(x))

            x = torch.cat([x, a4_t], dim=1)
            x = F.relu(self.fc2(x))

            logits = self.fc3(x)

            posterior = F.softmax(logits, dim=1)

            x = torch.multinomial(posterior, 1).float()
            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
            x = 2 * x / (2 ** self.n_bits - 1.0) - 1.0

            output.append(x)

        return torch.stack(output).permute(1, 2, 0), lengths