# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
from warnings import warn
import math

import torch
from typing import Optional

from . import functional as F
from .compliance import kaldi


__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'ComputeDeltas',
    'TimeStretch',
    'FrequencyMasking',
    'TimeMasking',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        pad (int): Two-sided padding of signal. (Default: ``0``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2., normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time)

        Returns:
            torch.Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1``, ``n_fft`` being the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


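# Example (an illustrative sketch, not part of the original module): a power
# spectrogram of one second of random mono audio.
#
# >>> waveform = torch.randn(1, 16000)               # (channel, time)
# >>> specgram = Spectrogram(n_fft=400)(waveform)    # (1, 201, frames), since 400 // 2 + 1 == 201

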
class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
        n_iter (int, optional): Number of iterations for the phase recovery process. (Default: ``32``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None,
                 window_fn=torch.hann_window, power=2., normalized=False, wkwargs=None,
                 momentum=0.99, length=None, rand_init=True):
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s >= 1 can be unstable' % momentum
        assert momentum > 0, 'momentum=%s <= 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, S):
        r"""
        Args:
            S (torch.Tensor): A spectrogram of dimension (..., freq, frames), where freq is
                ``n_fft // 2 + 1``, raised to the exponent ``power``.

        Returns:
            torch.Tensor: Reconstructed waveform of dimension (..., time).
        """
        return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


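# Example (an illustrative sketch, not part of the original module): recovering
# a waveform from the magnitude spectrogram produced by ``Spectrogram``.
#
# >>> specgram = Spectrogram(n_fft=400)(torch.randn(1, 16000))
# >>> waveform = GriffinLim(n_fft=400)(specgram)     # (1, time)

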
class AmplitudeToDB(torch.jit.ScriptModule):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a non-negative value')
        self.top_db = torch.jit.Attribute(top_db, Optional[float])
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x):
        r"""Numerically stable implementation from Librosa
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (torch.Tensor): Input tensor before being converted to decibel scale

        Returns:
            torch.Tensor: Output tensor in decibel scale
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


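# Example (an illustrative sketch, not part of the original module): converting
# a power spectrogram to decibels, clamped to an 80 dB dynamic range.
#
# >>> specgram = Spectrogram()(torch.randn(1, 16000))
# >>> specgram_db = AmplitudeToDB(stype='power', top_db=80.)(specgram)

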
class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    User can control which device the filter bank (``fb``) is on (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`.
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
        """

        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


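# Example (an illustrative sketch, not part of the original module): mapping a
# 201-bin linear spectrogram onto 128 mel bins.
#
# >>> specgram = Spectrogram(n_fft=400)(torch.randn(1, 16000))                      # (1, 201, frames)
# >>> mel_specgram = MelScale(n_mels=128, sample_rate=16000, n_stft=201)(specgram)  # (1, 128, frames)

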
class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``None``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2.,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
            ``False``)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time)

        Returns:
            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
        """

        # pack batch
        shape = waveform.size()
        waveform = waveform.reshape(-1, shape[-1])

        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])

        return mfcc


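# Example (an illustrative sketch, not part of the original module): 40 MFCCs
# from raw audio, with a custom mel spectrogram configuration.
#
# >>> waveform = torch.randn(1, 16000)
# >>> mfcc = MFCC(sample_rate=16000, n_mfcc=40,
# ...             melkwargs={'n_fft': 400, 'n_mels': 128})(waveform)  # (1, 40, frames)

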
class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded

        Returns:
            x_mu (torch.Tensor): An encoded signal
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

        Returns:
            torch.Tensor: The decoded signal
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


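# Example (an illustrative sketch, not part of the original module): a mu-law
# round trip on a signal already scaled to [-1, 1].
#
# >>> waveform = torch.rand(1, 16000) * 2 - 1
# >>> encoded = MuLawEncoding(quantization_channels=256)(waveform)   # integers in [0, 255]
# >>> decoded = MuLawDecoding(quantization_channels=256)(encoded)    # floats in [-1, 1]

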
class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can
    be given.

    Args:
        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
        new_freq (float): The desired frequency. (Default: ``16000``)
        resampling_method (str): The resampling method. (Default: ``'sinc_interpolation'``)
    """
    def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): The input signal of dimension (..., time)

        Returns:
            torch.Tensor: Output signal of dimension (..., time)
        """
        if self.resampling_method == 'sinc_interpolation':
            return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))


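# Example (an illustrative sketch, not part of the original module):
# downsampling from 44.1 kHz to 16 kHz.
#
# >>> waveform = torch.randn(1, 44100)
# >>> resampled = Resample(orig_freq=44100, new_freq=16000)(waveform)

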
class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input
    Args:
        power (float): Power of the norm. Defaults to `1.0`.
    """
    __constants__ = ['power']

    def __init__(self, power=1.0):
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor):
        r"""
        Args:
            complex_tensor (Tensor): Tensor of shape `(..., complex=2)`

        Returns:
            Tensor: norm of the input tensor, of shape `(...,)`
        """
        return F.complex_norm(complex_tensor, self.power)


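# Example (an illustrative sketch, not part of the original module): magnitudes
# of a complex spectrogram whose last dimension holds (real, imag).
#
# >>> complex_specgram = torch.randn(1, 201, 100, 2)
# >>> magnitudes = ComplexNorm(power=1.)(complex_specgram)   # (1, 201, 100)

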
class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Padding mode used when computing the deltas. (Default: ``'replicate'``)
    """
    __constants__ = ['win_length']

    def __init__(self, win_length=5, mode="replicate"):
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)

        Returns:
            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


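# Example (an illustrative sketch, not part of the original module): first-order
# deltas of a mel spectrogram; output shape matches the input.
#
# >>> mel_specgram = MelSpectrogram()(torch.randn(1, 16000))
# >>> deltas = ComputeDeltas(win_length=5)(mel_specgram)

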
class TimeStretch(torch.jit.ScriptModule):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int, optional): Number of audio frames between STFT columns. (Default: ``n_fft // 2``)
        n_freq (int, optional): Number of filter banks from stft. (Default: ``201``)
        fixed_rate (float, optional): Rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
    """
    __constants__ = ['fixed_rate']

    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]
        self.phase_advance = torch.jit.Attribute(phase_advance, torch.Tensor)

    def forward(self, complex_specgrams, overriding_rate=None):
        # type: (Tensor, Optional[float]) -> Tensor
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
            overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``

        Returns:
            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


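# Example (an illustrative sketch, not part of the original module): stretching
# a complex spectrogram to 1.2x speed without changing pitch.
#
# >>> complex_specgrams = torch.randn(1, 201, 100, 2)
# >>> stretched = TimeStretch(n_freq=201, fixed_rate=1.2)(complex_specgrams)

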
class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Whether to apply iid masks to each of the examples in the batch dimension.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param, axis, iid_masks):

        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram, mask_value=0.):
        # type: (Tensor, float) -> Tensor
        r"""
        Args:
            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
            mask_value (float): Value to fill the masked span with. (Default: ``0.``)

        Returns:
            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
        """

        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply different masks to
            each example/channel in the batch. (Default: ``False``)
    """

    def __init__(self, freq_mask_param, iid_masks=False):
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply different masks to
            each example/channel in the batch. (Default: ``False``)
    """

    def __init__(self, time_mask_param, iid_masks=False):
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
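

# Example (an illustrative sketch, not part of the original module):
# SpecAugment-style masking along frequency then time.
#
# >>> specgram = Spectrogram()(torch.randn(1, 16000))
# >>> masked = FrequencyMasking(freq_mask_param=30)(specgram)
# >>> masked = TimeMasking(time_mask_param=40)(masked)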