# -*- coding: utf-8 -*-

from __future__ import absolute_import, division, print_function, unicode_literals
from warnings import warn
import math

import torch
from typing import Optional

from torchaudio import functional as F
from torchaudio.compliance import kaldi

__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
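
    Example
        >>> # Illustrative usage; `waveform` here is a stand-in mono input tensor
        >>> waveform = torch.randn(1, 16000)  # (channel, time)
        >>> specgram = Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, n_frame)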
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2., normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

        Returns:
            torch.Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
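
    Example
        >>> # Illustrative round trip; `waveform` is a stand-in (channel, time) tensor
        >>> waveform = torch.randn(1, 16000)
        >>> specgram = Spectrogram(n_fft=400, power=2.)(waveform)
        >>> waveform_approx = GriffinLim(n_fft=400, power=2.)(specgram)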
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None,
                 window_fn=torch.hann_window, power=2., normalized=False, wkwargs=None,
                 momentum=0.99, length=None, rand_init=True):
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s >= 1 can be unstable' % momentum
        assert momentum >= 0, 'momentum=%s < 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, S):
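        r"""
        Args:
            S (torch.Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, n_frame).

        Returns:
            torch.Tensor: A waveform of dimension (..., time) reconstructed by phase recovery.
        """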
        return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
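
    Example
        >>> # Illustrative usage: convert a power spectrogram to decibels
        >>> specgram = Spectrogram(power=2.)(torch.randn(1, 16000))
        >>> specgram_db = AmplitudeToDB(stype='power', top_db=80.)(specgram)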
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x):
        r"""Numerically stable implementation from Librosa.
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (torch.Tensor): Input tensor before being converted to decibel scale.

        Returns:
            torch.Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control the device on which the filter bank (``fb``) resides
    (e.g., ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
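
    Example
        >>> # Illustrative usage; `fb` is built from the first input when n_stft is None
        >>> specgram = Spectrogram(n_fft=400)(torch.randn(1, 16000))  # (1, 201, time)
        >>> mel_specgram = MelScale(n_mels=128, sample_rate=16000)(specgram)  # (1, 128, time)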
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.view(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the Euclidean norm between the input mel-spectrogram and the product of
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict): Arguments for the SGD optimizer. (Default: ``None``)
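
    Example
        >>> # Illustrative usage; n_stft=201 assumes the mel spectrogram came from n_fft=400
        >>> mel_specgram = MelSpectrogram(sample_rate=16000)(torch.randn(1, 16000))
        >>> specgram = InverseMelScale(n_stft=201, n_mels=128, sample_rate=16000)(mel_specgram)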
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000,
                 tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None):
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, melspec):
        r"""
        Args:
            melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            torch.Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(dim=-1).mean()
            # take the sum over mel frequency, then average over the other dimensions
            # so that the loss threshold is applied per unit time frame
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2.,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
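
    Example
        >>> # Illustrative usage on a stand-in mono 16 kHz waveform tensor
        >>> waveform = torch.randn(1, 16000)  # (channel, time)
        >>> mfcc = MFCC(sample_rate=16000, n_mfcc=40)(waveform)  # (channel, n_mfcc, time)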
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

        Returns:
            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        # pack batch
        shape = waveform.size()
        waveform = waveform.view(-1, shape[-1])

        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])

        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
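
    Example
        >>> # Illustrative usage: encode a waveform already scaled to [-1, 1]
        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> x_mu = MuLawEncoding(quantization_channels=256)(waveform)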
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded.

        Returns:
            x_mu (torch.Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
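
    Example
        >>> # Illustrative round trip with MuLawEncoding on a [-1, 1] waveform
        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> x_mu = MuLawEncoding()(waveform)
        >>> waveform_approx = MuLawDecoding()(x_mu)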
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            torch.Tensor: The decoded signal.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (float, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
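
    Example
        >>> # Illustrative usage: downsample a stand-in waveform from 48 kHz to 16 kHz
        >>> waveform = torch.randn(1, 48000)  # (channel, time)
        >>> resampled = Resample(orig_freq=48000, new_freq=16000)(waveform)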
    """

    def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

        Returns:
            torch.Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':
            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
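
    Example
        >>> # Illustrative usage: power=None makes Spectrogram return a complex tensor
        >>> complex_specgram = Spectrogram(power=None)(torch.randn(1, 16000))  # (..., freq, time, 2)
        >>> magnitude = ComplexNorm(power=1.)(complex_specgram)  # (..., freq, time)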
    """
    __constants__ = ['power']

    def __init__(self, power=1.0):
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor):
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
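
    Example
        >>> # Illustrative usage on a stand-in spectrogram tensor
        >>> specgram = torch.randn(1, 40, 100)  # (channel, freq, time)
        >>> deltas = ComputeDeltas(win_length=5)(specgram)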
    """
    __constants__ = ['win_length']

    def __init__(self, win_length=5, mode="replicate"):
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows.
            (Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
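
    Example
        >>> # Illustrative usage: stretch a complex spectrogram to 1.25x speed
        >>> complex_specgram = Spectrogram(n_fft=400, power=None)(torch.randn(1, 16000))
        >>> stretched = TimeStretch(n_freq=201, fixed_rate=1.25)(complex_specgram)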
    """
    __constants__ = ['fixed_rate']

    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
614
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
615
616
617
618
619

    def forward(self, complex_specgrams, overriding_rate=None):
        # type: (Tensor, Optional[float]) -> Tensor
        r"""
        Args:
620
621
622
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
623
624

        Returns:
625
            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
626
        """
627
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
628
629
630
631
632
633
634
635
636
637
638
639

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

640
        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
641
642


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to a waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """
    def __init__(self, fade_in_len=0, fade_out_len=0, fade_shape="linear"):
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform):
        # type: (Tensor) -> Tensor
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

        Returns:
            torch.Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]

        return self._fade_in(waveform_length) * self._fade_out(waveform_length) * waveform

    def _fade_in(self, waveform_length):
        # type: (int) -> Tensor
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length):
        # type: (int) -> Tensor
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param, axis, iid_masks):
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram, mask_value=0.):
        # type: (Tensor, float) -> Tensor
        r"""
        Args:
            specgram (torch.Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            torch.Tensor: Masked spectrogram of dimensions (..., freq, time).
        """
        # if the iid_masks flag is set and the specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
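
    Example
        >>> # Illustrative usage (SpecAugment-style frequency mask) on a stand-in spectrogram
        >>> specgram = torch.randn(1, 201, 100)  # (channel, freq, time)
        >>> masked = FrequencyMasking(freq_mask_param=30)(specgram)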
    """

    def __init__(self, freq_mask_param, iid_masks=False):
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
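
    Example
        >>> # Illustrative usage (SpecAugment-style time mask) on a stand-in spectrogram
        >>> specgram = torch.randn(1, 201, 100)  # (channel, freq, time)
        >>> masked = TimeMasking(time_mask_param=40)(specgram)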
    """

    def __init__(self, time_mask_param, iid_masks=False):
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)