# -*- coding: utf-8 -*-

import math

from typing import Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchaudio import functional as F
from torchaudio.compliance import kaldi


__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'ComputeDeltas',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'Vol',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
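
    Example
        >>> # Illustrative usage; 'test.wav' is a placeholder audio file.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, time)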
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int or None, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
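
    Example
        >>> # Illustrative sketch of phase recovery; the power value must match
        >>> # the one used to compute the input spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400, power=2.)(waveform)
        >>> restored = transforms.GriffinLim(n_fft=400, power=2.)(specgram)  # (channel, time)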
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s >= 1 can be unstable' % momentum
        assert momentum >= 0, 'momentum=%s < 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)
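
    Example
        >>> # Illustrative: convert a power spectrogram to decibels, clipped 80 dB below the peak.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> specgram_db = transforms.AmplitudeToDB(stype='power', top_db=80.)(specgram)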
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is on (e.g. ``fb.to(spec_f.device)``).

209
    Args:
210
211
212
213
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
214
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
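
    Example
        >>> # Illustrative: n_stft should equal n_fft // 2 + 1 of the input spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, 201, time)
        >>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=sample_rate, n_stft=201)(specgram)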
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.view(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidean norm between the input mel-spectrogram and the product of
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
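
    Example
        >>> # Illustrative sketch; the result is approximate since the mel projection is lossy.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, 128, time)
        >>> specgram = transforms.InverseMelScale(n_stft=201, n_mels=128)(mel_specgram)  # (channel, 201, time)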
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None) -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency, then average over other dimensions,
            # so that the loss threshold is applied per unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2.,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
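
    Example
        >>> # Illustrative usage; 'test.wav' is a placeholder audio file.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mfcc = transforms.MFCC(sample_rate=sample_rate, n_mfcc=40)(waveform)  # (channel, n_mfcc, time)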
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        # pack batch
        shape = waveform.size()
        waveform = waveform.view(-1, shape[-1])

        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])

        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
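
    Example
        >>> # Illustrative: the input is expected to be scaled to [-1., 1.].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> encoded = transforms.MuLawEncoding(quantization_channels=256)(waveform)  # values in [0, 255]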
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
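
    Example
        >>> # Illustrative round trip; decoding is approximate since quantization discards information.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> encoded = transforms.MuLawEncoding()(waveform)
        >>> decoded = transforms.MuLawDecoding()(encoded)  # scaled back to [-1., 1.]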
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The signal decoded.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
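
    Example
        >>> # Illustrative: downsample a recording to 8 kHz.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> resampled = transforms.Resample(orig_freq=sample_rate, new_freq=8000)(waveform)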
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':

            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
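
    Example
        >>> # Illustrative: Spectrogram with power=None returns a complex tensor (..., freq, time, 2).
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> complex_specgram = transforms.Spectrogram(power=None)(waveform)
        >>> magnitude = transforms.ComplexNorm(power=1.)(complex_specgram)  # (channel, freq, time)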
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
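
    Example
        >>> # Illustrative: the deltas have the same shape as the input spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.MelSpectrogram(sample_rate)(waveform)
        >>> deltas = transforms.ComputeDeltas(win_length=5)(specgram)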
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``n_freq - 1``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
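
    Example
        >>> # Illustrative: n_freq must match the spectrogram, here n_fft // 2 + 1 == 201.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> complex_specgram = transforms.Spectrogram(n_fft=400, power=None)(waveform)
        >>> stretched = transforms.TimeStretch(n_freq=201, fixed_rate=1.2)(complex_specgram)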
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to an waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """
    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]

        return self._fade_in(waveform_length) * self._fade_out(waveform_length) * waveform

    def _fade_in(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:

        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
792
            specgram (Tensor): Tensor of dimension (..., freq, time).
793
            mask_value (float): Value to assign to the masked columns.
794
795

        Returns:
796
            Tensor: Masked spectrogram of dimensions (..., freq, time).
797
798
799
800
801
802
        """

        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
803
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
804
805
806


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to
            each example/channel in the batch. (Default: ``False``)
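
    Example
        >>> # Illustrative: zero out a random band of at most 30 consecutive frequency bins.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.MelSpectrogram(sample_rate)(waveform)
        >>> masked = transforms.FrequencyMasking(freq_mask_param=30)(specgram)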
    """

    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to
            each example/channel in the batch. (Default: ``False``)
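
    Example
        >>> # Illustrative: zero out a random span of at most 50 consecutive time steps.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.MelSpectrogram(sample_rate)(waveform)
        >>> masked = transforms.TimeMasking(time_mask_param=50)(specgram)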
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class Vol(torch.nn.Module):
    r"""Adjust the volume of a waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = ``'amplitude'``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``'power'``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``'db'``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``'amplitude'``, ``'power'``, ``'db'`` (Default: ``'amplitude'``)
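
    Example
        >>> # Illustrative: halve the amplitude; the output is clamped to [-1., 1.].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> quieter = transforms.Vol(gain=0.5, gain_type='amplitude')(waveform)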
    """

    def __init__(self, gain: float, gain_type: str = 'amplitude') -> None:
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)