# -*- coding: utf-8 -*-

import math
from typing import Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchaudio import functional as F
from torchaudio.compliance import kaldi


__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'ComputeDeltas',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'Vol',
    'SlidingWindowCmn',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
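
    Example
        A minimal usage sketch; assumes ``waveform`` is a ``(channel, time)`` tensor
        loaded elsewhere (e.g. via ``torchaudio.load``).

        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, 201, n_frame)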
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
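
    Example
        A minimal sketch; assumes ``waveform`` is a ``(channel, time)`` tensor and
        round-trips it through a power spectrogram.

        >>> specgram = transforms.Spectrogram(n_fft=400, power=2)(waveform)
        >>> reconstructed = transforms.GriffinLim(n_fft=400, power=2)(specgram)  # (channel, time)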
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
        assert momentum > 0, 'momentum=%s < 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
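
    Example
        An illustrative sketch; assumes ``specgram`` is a power spectrogram,
        e.g. from :class:`Spectrogram`.

        >>> specgram_db = transforms.AmplitudeToDB(stype='power', top_db=80.)(specgram)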
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is on (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
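
    Example
        A minimal sketch; assumes ``specgram`` is a ``(channel, freq, time)`` power
        spectrogram with ``freq == 201`` (i.e. ``n_fft=400``).

        >>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=16000, n_stft=201)(specgram)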
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """

        # pack batch
        shape = specgram.size()
        specgram = specgram.view(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidean norm between the input mel-spectrogram and the product between
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
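
    Example
        An illustrative sketch; assumes ``mel_specgram`` is a ``(channel, 128, time)``
        mel spectrogram computed with ``n_fft=400`` (so the linear scale has 201 bins).

        >>> specgram = transforms.InverseMelScale(n_stft=201, n_mels=128)(mel_specgram)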
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None) -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that loss threshold is applied per unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2.,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
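
    Example
        A minimal sketch; assumes ``waveform`` is a ``(channel, time)`` tensor.

        >>> mfcc = transforms.MFCC(sample_rate=16000, n_mfcc=40)(waveform)  # (channel, 40, time)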
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """

        # pack batch
        shape = waveform.size()
        waveform = waveform.view(-1, shape[-1])

        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])

        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
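
    Example
        An illustrative sketch; assumes ``waveform`` is a tensor scaled to [-1., 1.].

        >>> quantized = transforms.MuLawEncoding(quantization_channels=256)(waveform)  # values in [0, 255]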
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
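
    Example
        An illustrative sketch; assumes ``quantized`` holds codes in [0, 255],
        e.g. from :class:`MuLawEncoding`.

        >>> decoded = transforms.MuLawDecoding(quantization_channels=256)(quantized)  # values in [-1., 1.]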
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The decoded signal.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
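
    Example
        A minimal sketch; assumes ``waveform`` is a ``(channel, time)`` tensor
        recorded at 44100 Hz.

        >>> resampled = transforms.Resample(orig_freq=44100, new_freq=16000)(waveform)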
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':

            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
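
    Example
        An illustrative sketch; assumes ``complex_specgram`` is a complex spectrogram
        of shape ``(..., freq, time, 2)``, e.g. from ``transforms.Spectrogram(power=None)``.

        >>> magnitude = transforms.ComplexNorm(power=1.0)(complex_specgram)  # (..., freq, time)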
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int, optional): The window length used for computing delta. (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
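
    Example
        A minimal sketch; assumes ``specgram`` is a ``(channel, freq, time)`` tensor.

        >>> deltas = transforms.ComputeDeltas(win_length=5)(specgram)  # (channel, freq, time)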
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``n_freq - 1``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
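
    Example
        An illustrative sketch; assumes ``complex_specgram`` is a complex spectrogram
        of shape ``(channel, 201, time, 2)``, e.g. from ``transforms.Spectrogram(n_fft=400, power=None)``.

        >>> stretched = transforms.TimeStretch(n_freq=201, fixed_rate=1.3)(complex_specgram)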
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to a waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """
    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
        return self._fade_in(waveform_length).to(device) * \
            self._fade_out(waveform_length).to(device) * waveform

    def _fade_in(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:

        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions (..., freq, time).
        """

        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to
            each example/channel in the batch. (Default: ``False``)
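
    Example
        A minimal sketch; assumes ``specgram`` is a ``(channel, freq, time)`` tensor.

        >>> masked = transforms.FrequencyMasking(freq_mask_param=30)(specgram)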
    """

    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to
            each example/channel in the batch. (Default: ``False``)
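
    Example
        A minimal sketch; assumes ``specgram`` is a ``(channel, freq, time)`` tensor.

        >>> masked = transforms.TimeMasking(time_mask_param=80)(specgram)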
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class Vol(torch.nn.Module):
    r"""Add a volume to a waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = 'amplitude', ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = 'power', ``gain`` is a power (voltage squared).
            If ``gain_type`` = 'db', ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: 'amplitude', 'power', 'db' (Default: ``'amplitude'``)
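
    Example
        A minimal sketch; assumes ``waveform`` is a ``(channel, time)`` tensor scaled
        to [-1., 1.]; the output is clamped back to that range.

        >>> louder = transforms.Vol(gain=2., gain_type='amplitude')(waveform)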
    """

    def __init__(self, gain: float, gain_type: str = 'amplitude'):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation. (Default: ``600``)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if ``center == False``, ignored if ``center == True``. (Default: ``100``)
        center (bool, optional): If ``True``, use a window centered on the current frame
            (to the extent possible, modulo end effects). If ``False``, window is to the left. (Default: ``False``)
        norm_vars (bool, optional): If ``True``, normalize variance to one. (Default: ``False``)
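
    Example
        An illustrative sketch; assumes ``waveform`` is a tensor of dimension
        ``(..., time)`` as described below.

        >>> cmn_waveform = transforms.SlidingWindowCmn(cmn_window=600)(waveform)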
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform