from typing import List, Tuple

import torch
from torch import Tensor
from torch import nn

__all__ = ["ResBlock", "MelResNet", "Stretch2d", "UpsampleNetwork", "WaveRNN"]


class ResBlock(nn.Module):
jimchen90's avatar
jimchen90 committed
11
12
13
    r"""ResNet block based on "Deep Residual Learning for Image Recognition"

    The paper link is https://arxiv.org/pdf/1512.03385.pdf.
jimchen90's avatar
jimchen90 committed
14
15

    Args:
16
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
jimchen90's avatar
jimchen90 committed
17

jimchen90's avatar
jimchen90 committed
18
    Examples
19
        >>> resblock = ResBlock()
jimchen90's avatar
jimchen90 committed
20
21
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = resblock(input)  # shape: (10, 128, 512)
jimchen90's avatar
jimchen90 committed
22
23
    """

jimchen90's avatar
jimchen90 committed
24
    def __init__(self, n_freq: int = 128) -> None:
jimchen90's avatar
jimchen90 committed
25
26
27
        super().__init__()

        self.resblock_model = nn.Sequential(
jimchen90's avatar
jimchen90 committed
28
29
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq),
jimchen90's avatar
jimchen90 committed
30
            nn.ReLU(inplace=True),
jimchen90's avatar
jimchen90 committed
31
32
            nn.Conv1d(in_channels=n_freq, out_channels=n_freq, kernel_size=1, bias=False),
            nn.BatchNorm1d(n_freq)
jimchen90's avatar
jimchen90 committed
33
34
        )

jimchen90's avatar
jimchen90 committed
35
    def forward(self, specgram: Tensor) -> Tensor:
36
        r"""Pass the input through the ResBlock layer.
jimchen90's avatar
jimchen90 committed
37
        Args:
38
            specgram (Tensor): the input sequence to the ResBlock layer (n_batch, n_freq, n_time).
jimchen90's avatar
jimchen90 committed
39

jimchen90's avatar
jimchen90 committed
40
41
        Return:
            Tensor shape: (n_batch, n_freq, n_time)
jimchen90's avatar
jimchen90 committed
42
43
        """

jimchen90's avatar
jimchen90 committed
44
        return self.resblock_model(specgram) + specgram
jimchen90's avatar
jimchen90 committed
45
46


47
class MelResNet(nn.Module):
    r"""MelResNet layer uses a stack of ResBlocks on spectrogram.

    Args:
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> melresnet = MelResNet()
        >>> input = torch.rand(10, 128, 512)  # a random spectrogram
        >>> output = melresnet(input)  # shape: (10, 128, 508)
    """

    def __init__(self,
                 n_res_block: int = 10,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128,
                 kernel_size: int = 5) -> None:
        super().__init__()

        # Front-end conv (no padding, so n_time shrinks by kernel_size - 1),
        # a stack of residual blocks, then a 1x1 projection to n_output.
        stack: List[nn.Module] = [
            nn.Conv1d(in_channels=n_freq, out_channels=n_hidden, kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(inplace=True),
        ]
        stack.extend(ResBlock(n_hidden) for _ in range(n_res_block))
        stack.append(nn.Conv1d(in_channels=n_hidden, out_channels=n_output, kernel_size=1))
        self.melresnet_model = nn.Sequential(*stack)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the MelResNet layer.

        Args:
            specgram (Tensor): the input sequence to the MelResNet layer (n_batch, n_freq, n_time).

        Return:
            Tensor shape: (n_batch, n_output, n_time - kernel_size + 1)
        """

        return self.melresnet_model(specgram)


class Stretch2d(nn.Module):
    r"""Upscale the frequency and time dimensions of a spectrogram.

    Args:
        time_scale: the scale factor in time dimension
        freq_scale: the scale factor in frequency dimension

    Examples
        >>> stretch2d = Stretch2d(time_scale=10, freq_scale=5)
        >>> input = torch.rand(10, 100, 512)  # a random spectrogram
        >>> output = stretch2d(input)  # shape: (10, 500, 5120)
    """

    def __init__(self,
                 time_scale: int,
                 freq_scale: int) -> None:
        super().__init__()

        self.freq_scale = freq_scale
        self.time_scale = time_scale

    def forward(self, specgram: Tensor) -> Tensor:
        r"""Pass the input through the Stretch2d layer.

        Args:
            specgram (Tensor): the input sequence to the Stretch2d layer (..., n_freq, n_time).

        Return:
            Tensor shape: (..., n_freq * freq_scale, n_time * time_scale)
        """

        # Nearest-neighbour upscaling: duplicate each frequency bin
        # freq_scale times, then each time step time_scale times.
        stretched = specgram.repeat_interleave(self.freq_scale, dim=-2)
        return stretched.repeat_interleave(self.time_scale, dim=-1)


class UpsampleNetwork(nn.Module):
    r"""Upscale the dimensions of a spectrogram.

    Args:
        upsample_scales: the list of upsample scales.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)

    Examples
        >>> upsamplenetwork = UpsampleNetwork(upsample_scales=[4, 4, 16])
        >>> input = torch.rand(10, 128, 10)  # a random spectrogram
        >>> output = upsamplenetwork(input)  # shape: (10, 1536, 128), (10, 1536, 128)
    """

    def __init__(self,
                 upsample_scales: List[int],
                 n_res_block: int = 10,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128,
                 kernel_size: int = 5) -> None:
        super().__init__()

        # Overall time upscaling is the product of all stage scales.
        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale

        # Number of upsampled frames lost at each edge by the resnet's
        # valid (unpadded) front-end convolution; trimmed off in ``forward``.
        self.indent = (kernel_size - 1) // 2 * total_scale
        self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
        self.resnet_stretch = Stretch2d(total_scale, 1)

        # Each stage: nearest-neighbour stretch in time, then a fixed
        # box-filter conv (weights 1/(2*scale+1)) for smoothing.
        up_layers = []
        for scale in upsample_scales:
            kernel_width = scale * 2 + 1
            smoothing = nn.Conv2d(in_channels=1,
                                  out_channels=1,
                                  kernel_size=(1, kernel_width),
                                  padding=(0, scale),
                                  bias=False)
            smoothing.weight.data.fill_(1. / kernel_width)
            up_layers.extend([Stretch2d(scale, 1), smoothing])
        self.upsample_layers = nn.Sequential(*up_layers)

    def forward(self, specgram: Tensor) -> Tuple[Tensor, Tensor]:
        r"""Pass the input through the UpsampleNetwork layer.

        Args:
            specgram (Tensor): the input sequence to the UpsampleNetwork layer (n_batch, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale),
                          (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        where total_scale is the product of all elements in upsample_scales.
        """

        # Auxiliary conditioning features: resnet output stretched to the
        # same time resolution as the upsampled spectrogram.
        aux = self.resnet(specgram).unsqueeze(1)
        aux = self.resnet_stretch(aux)
        aux = aux.squeeze(1)

        # Upsample the raw spectrogram and trim the edge frames that the
        # resnet's valid convolution cannot condition on.
        upsampled = self.upsample_layers(specgram.unsqueeze(1))
        upsampled = upsampled.squeeze(1)[:, :, self.indent:-self.indent]

        return upsampled, aux


class WaveRNN(nn.Module):
    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.

    The original implementation was introduced in
    `"Efficient Neural Audio Synthesis" <https://arxiv.org/pdf/1802.08435.pdf>`_.
    The input channels of waveform and spectrogram have to be 1. The product of
    `upsample_scales` must equal `hop_length`.

    Args:
        upsample_scales: the list of upsample scales.
        n_classes: the number of output classes.
        hop_length: the number of samples between the starts of consecutive frames.
        n_res_block: the number of ResBlock in stack. (Default: ``10``)
        n_rnn: the dimension of RNN layer. (Default: ``512``)
        n_fc: the dimension of fully connected layer. (Default: ``512``)
        kernel_size: the number of kernel size in the first Conv1d layer. (Default: ``5``)
        n_freq: the number of bins in a spectrogram. (Default: ``128``)
        n_hidden: the number of hidden dimensions of resblock. (Default: ``128``)
        n_output: the number of output dimensions of melresnet. (Default: ``128``)

    Raises:
        ValueError: if the product of ``upsample_scales`` does not equal ``hop_length``.

    Example
        >>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200)
        >>> waveform, sample_rate = torchaudio.load(file)
        >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length)
        >>> specgram = MelSpectrogram(sample_rate)(waveform)  # shape: (n_batch, n_channel, n_freq, n_time)
        >>> output = wavernn(waveform, specgram)
        >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
    """

    def __init__(self,
                 upsample_scales: List[int],
                 n_classes: int,
                 hop_length: int,
                 n_res_block: int = 10,
                 n_rnn: int = 512,
                 n_fc: int = 512,
                 kernel_size: int = 5,
                 n_freq: int = 128,
                 n_hidden: int = 128,
                 n_output: int = 128) -> None:
        super().__init__()

        self.kernel_size = kernel_size
        self.n_rnn = n_rnn
        # The conditioning features from the upsampler are split into 4 equal
        # chunks of n_aux channels, one per stage of the network (see forward).
        self.n_aux = n_output // 4
        self.hop_length = hop_length
        self.n_classes = n_classes

        total_scale = 1
        for upsample_scale in upsample_scales:
            total_scale *= upsample_scale
        if total_scale != self.hop_length:
            raise ValueError(f"Expected: total_scale == hop_length, but found {total_scale} != {hop_length}")

        self.upsample = UpsampleNetwork(upsample_scales,
                                        n_res_block,
                                        n_freq,
                                        n_hidden,
                                        n_output,
                                        kernel_size)
        # Input per time step: one waveform sample + n_freq spectrogram bins
        # + the first n_aux conditioning channels.
        self.fc = nn.Linear(n_freq + self.n_aux + 1, n_rnn)

        self.rnn1 = nn.GRU(n_rnn, n_rnn, batch_first=True)
        self.rnn2 = nn.GRU(n_rnn + self.n_aux, n_rnn, batch_first=True)

        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)

        self.fc1 = nn.Linear(n_rnn + self.n_aux, n_fc)
        self.fc2 = nn.Linear(n_fc + self.n_aux, n_fc)
        self.fc3 = nn.Linear(n_fc, self.n_classes)

    def forward(self, waveform: Tensor, specgram: Tensor) -> Tensor:
        r"""Pass the input through the WaveRNN model.

        Args:
            waveform: the input waveform to the WaveRNN layer (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
            specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)

        Return:
            Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)

        Raises:
            ValueError: if the channel dimension of ``waveform`` or ``specgram`` is not 1.
        """

        # Validate with explicit exceptions instead of ``assert``: asserts are
        # stripped under ``python -O``, silently disabling these checks.
        if waveform.size(1) != 1:
            raise ValueError('Require the input channel of waveform is 1')
        if specgram.size(1) != 1:
            raise ValueError('Require the input channel of specgram is 1')
        # remove channel dimension until the end
        waveform, specgram = waveform.squeeze(1), specgram.squeeze(1)

        batch_size = waveform.size(0)
        # Fresh zero initial hidden states for both GRUs on every call.
        h1 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        h2 = torch.zeros(1, batch_size, self.n_rnn, dtype=waveform.dtype, device=waveform.device)
        # output of upsample:
        # specgram: (n_batch, n_freq, (n_time - kernel_size + 1) * total_scale)
        # aux: (n_batch, n_output, (n_time - kernel_size + 1) * total_scale)
        specgram, aux = self.upsample(specgram)
        specgram = specgram.transpose(1, 2)
        aux = aux.transpose(1, 2)

        # Split the conditioning features into 4 consecutive chunks of n_aux
        # channels, fed into successive stages of the network.
        aux_idx = [self.n_aux * i for i in range(5)]
        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]

        x = torch.cat([waveform.unsqueeze(-1), specgram, a1], dim=-1)
        x = self.fc(x)
        res = x
        x, _ = self.rnn1(x, h1)

        # Residual (skip) connections around each GRU stage.
        x = x + res
        res = x
        x = torch.cat([x, a2], dim=-1)
        x, _ = self.rnn2(x, h2)

        x = x + res
        x = torch.cat([x, a3], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)

        x = torch.cat([x, a4], dim=-1)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)

        # bring back channel dimension
        return x.unsqueeze(1)