"torchvision/models/vision_transformer.py" did not exist on "b9da6db4f923e49ab7431c7b6915037101622611"
transforms.py 56.8 KB
Newer Older
# -*- coding: utf-8 -*-

import math
import warnings

from typing import Callable, Optional

import torch
from torch import Tensor
from torchaudio import functional as F

from .functional.functional import (
    _get_sinc_resample_kernel,
    _apply_sinc_resample_kernel,
)

__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'LFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'SlidingWindowCmn',
    'Vad',
    'SpectralCentroid',
    'Vol',
    'ComputeDeltas',
    'PitchShift',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. (Default: ``True``)
        return_complex (bool, optional):
            Indicates whether the resulting complex-valued Tensor should be represented with
            native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
            mimicking complex value with an extra dimension for real and imaginary parts.
            (See also ``torch.view_as_real``.)
            This argument is only effective when ``power=None``. It is ignored for
            cases where ``power`` is a number as in those cases, the returned tensor is
            a power spectrogram, which is a real-valued tensor.
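
    Example
        >>> # Minimal usage sketch; 'test.wav' is a placeholder audio file.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.Spectrogram(n_fft=400)
        >>> spectrogram = transform(waveform)  # (channel, n_fft // 2 + 1, n_frame)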
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 return_complex: bool = True) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        self.return_complex = return_complex

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(
            waveform,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided,
            self.return_complex,
        )


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    Implementation ported from
    *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`]
    and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
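
    Example
        >>> # Minimal usage sketch; assumes a magnitude spectrogram produced with matching n_fft and power.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.Spectrogram(n_fft=400, power=2)(waveform)
        >>> transform = torchaudio.transforms.GriffinLim(n_fft=400, power=2)
        >>> reconstructed = transform(specgram)  # (channel, time)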
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
        assert momentum >= 0, 'momentum={} < 0'.format(momentum)

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor):
                A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
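
    Example
        >>> # Minimal usage sketch; converts a power spectrogram to decibels.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.Spectrogram()(waveform)
        >>> transform = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80.)
        >>> specgram_db = transform(specgram)  # (channel, freq, time)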
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
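
    Example
        >>> # Minimal usage sketch; n_stft should match the spectrogram's n_fft // 2 + 1.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)
        >>> transform = torchaudio.transforms.MelScale(n_mels=128, sample_rate=sample_rate, n_stft=201)
        >>> mel_specgram = transform(specgram)  # (channel, n_mels, time)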
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: int = 201,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)
        fb = F.melscale_fbanks(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
            self.mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidean norm between the input mel-spectrogram and the product between
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
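
    Example
        >>> # Minimal usage sketch; estimates a linear spectrogram from a mel spectrogram by SGD.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> mel_specgram = torchaudio.transforms.MelSpectrogram(sample_rate, n_fft=400)(waveform)
        >>> transform = torchaudio.transforms.InverseMelScale(n_stft=201, n_mels=128)
        >>> specgram = transform(mel_specgram)  # (channel, 201, time)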
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

        fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate,
                               norm, mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that loss threshold is applied per unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. (Default: ``True``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.MelSpectrogram(sample_rate)
        >>> mel_specgram = transform(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=self.power,
                                       normalized=self.normalized, wkwargs=wkwargs,
                                       center=center, pad_mode=pad_mode, onesided=onesided)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            norm,
            mel_scale
        )

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
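
    Example
        >>> # Minimal usage sketch; melkwargs are forwarded to the underlying MelSpectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=40)
        >>> mfcc = transform(waveform)  # (channel, n_mfcc, time)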
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        melkwargs = melkwargs or {}
        self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (..., time, n_mels) dot (n_mels, n_mfcc) -> (..., n_mfcc, time)
        mfcc = torch.matmul(mel_specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
        return mfcc


class LFCC(torch.nn.Module):
    r"""Create the linear-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the LFCC on the DB-scaled linear filtered spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_filter (int, optional): Number of linear filters to apply. (Default: ``128``)
        n_lfcc (int, optional): Number of lfc coefficients to retain. (Default: ``40``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_lf (bool, optional): whether to use log-lf spectrograms instead of db-scaled. (Default: ``False``)
        speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``)
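
    Example
        >>> # Minimal usage sketch; speckwargs are forwarded to the underlying Spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.LFCC(sample_rate=sample_rate, n_lfcc=40)
        >>> lfcc = transform(waveform)  # (channel, n_lfcc, time)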
    """
    __constants__ = ['sample_rate', 'n_filter', 'n_lfcc', 'dct_type', 'top_db', 'log_lf']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_filter: int = 128,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_lfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_lf: bool = False,
                 speckwargs: Optional[dict] = None) -> None:
        super(LFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.f_min = f_min
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.n_filter = n_filter
        self.n_lfcc = n_lfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        speckwargs = speckwargs or {}
        self.Spectrogram = Spectrogram(**speckwargs)

        if self.n_lfcc > self.Spectrogram.n_fft:
            raise ValueError('Cannot select more LFCC coefficients than # fft bins')

        filter_mat = F.linear_fbanks(
            n_freqs=self.Spectrogram.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_filter=self.n_filter,
            sample_rate=self.sample_rate,
        )
        self.register_buffer("filter_mat", filter_mat)

        dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_lf = log_lf

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Linear Frequency Cepstral Coefficients of size (..., ``n_lfcc``, time).
        """
        specgram = self.Spectrogram(waveform)

        # (..., time, freq) dot (freq, n_filter) -> (..., n_filter, time)
        specgram = torch.matmul(specgram.transpose(-1, -2), self.filter_mat).transpose(-1, -2)

        if self.log_lf:
            log_offset = 1e-6
            specgram = torch.log(specgram + log_offset)
        else:
            specgram = self.amplitude_to_DB(specgram)

        # (..., time, n_filter) dot (n_filter, n_lfcc) -> (..., n_lfcc, time)
        lfcc = torch.matmul(specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
        return lfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)

    Example
       >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
       >>> transform = torchaudio.transforms.MuLawEncoding(quantization_channels=512)
       >>> mulawtrans = transform(waveform)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
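
    Example
        >>> # Minimal usage sketch; inverts MuLawEncoding applied with the same number of channels.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> x_mu = torchaudio.transforms.MuLawEncoding(quantization_channels=512)(waveform)
        >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
        >>> decoded = transform(x_mu)  # values scaled between -1 and 1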
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The signal decoded.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Note:
        If resampling on waveforms of higher precision than float32, there may be a small loss of precision
        because the kernel is cached once as float32. If high precision resampling is important for your application,
        the functional form will retain higher precision, but run slower because it does not cache the kernel.
        Alternatively, you could rewrite a transform that caches a higher precision kernel.

    Args:
        orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (float, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method to use.
            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. (Default: ``6``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
        beta (float or None): The shape parameter used for kaiser window.
        dtype (torch.dtype, optional):
            Determines the precision that resampling kernel is pre-computed and cached. If not provided,
            kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
            If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
            cached as ``torch.float64``. If you use resample with lower precision, then instead of providing this
            argument, please use ``Resample.to(dtype)``, so that the kernel generation is still
            carried out on ``torch.float64``.

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.Resample(sample_rate, sample_rate/10)
        >>> waveform = transform(waveform)
    """

    def __init__(
            self,
            orig_freq: float = 16000,
            new_freq: float = 16000,
            resampling_method: str = 'sinc_interpolation',
            lowpass_filter_width: int = 6,
            rolloff: float = 0.99,
            beta: Optional[float] = None,
            *,
            dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
        self.resampling_method = resampling_method
        self.lowpass_filter_width = lowpass_filter_width
        self.rolloff = rolloff
        self.beta = beta

        if self.orig_freq != self.new_freq:
            kernel, self.width = _get_sinc_resample_kernel(
                self.orig_freq, self.new_freq, self.gcd,
                self.lowpass_filter_width, self.rolloff,
                self.resampling_method, beta, dtype=dtype)
            self.register_buffer('kernel', kernel)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.orig_freq == self.new_freq:
            return waveform
        return _apply_sinc_resample_kernel(
            waveform, self.orig_freq, self.new_freq, self.gcd,
            self.kernel, self.width)


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: to ``1.0``)

    Example
        >>> complex_tensor = ... #  Tensor shape of (…, complex=2)
        >>> transform = transforms.ComplexNorm(power=2)
        >>> complex_norm = transform(complex_tensor)
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        warnings.warn(
            'torchaudio.transforms.ComplexNorm has been deprecated '
            'and will be removed from future release. '
            'Please convert the input Tensor to complex type with `torch.view_as_complex` then '
            'use `torch.abs` and `torch.angle`. '
            'Please refer to https://github.com/pytorch/audio/issues/1337 '
            "for more details about torchaudio's plan to migrate to native complex type."
        )
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
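
    Example
        >>> # Minimal usage sketch; computes deltas of an MFCC feature matrix.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate)(waveform)
        >>> transform = torchaudio.transforms.ComputeDeltas(win_length=5)
        >>> deltas = transform(mfcc)  # (channel, n_mfcc, time)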
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
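
    Example
        >>> # Minimal usage sketch; operates on a complex spectrogram (power=None).
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> complex_specgram = torchaudio.transforms.Spectrogram(n_fft=400, power=None)(waveform)
        >>> transform = torchaudio.transforms.TimeStretch(n_freq=201, fixed_rate=1.3)
        >>> stretched = transform(complex_specgram)  # (channel, n_freq, ceil(num_frame / 1.3))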
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor):
                Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor:
                Stretched spectrogram. The resulting tensor is of the same dtype as the input
                spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
        """
        if overriding_rate is None:
            if self.fixed_rate is None:
                raise ValueError(
                    "If no fixed_rate is specified, must pass a valid rate to the forward method.")
            rate = self.fixed_rate
        else:
            rate = overriding_rate
        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to an waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """

    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
        return self._fade_in(waveform_length).to(device) * \
            self._fade_out(waveform_length).to(device) * waveform

    def _fade_in(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions (..., freq, time).
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
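
    Example
        >>> # Minimal usage sketch; masks a random band of up to 80 frequency bins.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.MelSpectrogram(sample_rate)(waveform)
        >>> transform = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)
        >>> masked = transform(specgram)  # (channel, n_mels, time)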
    """

    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
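
    Example
        >>> # Minimal usage sketch; masks a random span of up to 100 time frames.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.MelSpectrogram(sample_rate)(waveform)
        >>> transform = torchaudio.transforms.TimeMasking(time_mask_param=100)
        >>> masked = transform(specgram)  # (channel, n_mels, time)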
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class Vol(torch.nn.Module):
    r"""Add a volume to an waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
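
    Example
        >>> # Minimal usage sketch; boosts the waveform by 3 dB and clamps to [-1, 1].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.Vol(gain=3.0, gain_type='db')
        >>> louder = transform(waveform)  # (channel, time)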
    """

    def __init__(self, gain: float, gain_type: str = 'amplitude'):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
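
    Example
        >>> # Minimal usage sketch; applies sliding-window CMN to the input tensor.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.SlidingWindowCmn(cmn_window=600)
        >>> normalized = transform(waveform)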
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform


class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be cahnged depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float or None, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)
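
    Example
        >>> # A usage sketch: trim leading silence, then reverse, re-apply, and reverse again
        >>> # to trim the trailing end, as described above.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> vad = transforms.Vad(sample_rate=sample_rate, trigger_level=7.0)
        >>> trimmed_front = vad(waveform)
        >>> trimmed_both = torch.flip(vad(torch.flip(trimmed_front, [-1])), [-1])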

    Reference:
        - http://sox.sourceforge.net/sox.html
    """

    def __init__(self,
                 sample_rate: int,
                 trigger_level: float = 7.0,
                 trigger_time: float = 0.25,
                 search_time: float = 1.0,
                 allowed_gap: float = 0.25,
                 pre_trigger_time: float = 0.0,
                 boot_time: float = .35,
                 noise_up_time: float = .1,
                 noise_down_time: float = .01,
                 noise_reduction_amount: float = 1.35,
                 measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None,
                 measure_smooth_time: float = .4,
                 hp_filter_freq: float = 50.,
                 lp_filter_freq: float = 6000.,
                 hp_lifter_freq: float = 150.,
                 lp_lifter_freq: float = 2000.) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`.
                A Tensor of shape `(channels, time)` is treated as a multi-channel recording
                of the same event, and the resulting output will be trimmed to the earliest
                voice activity in any channel.

        Returns:
            Tensor: The audio with leading silence trimmed, of dimension `(channels, time)` or `(time)`.
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )


class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.
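
    Written out, for frame :math:`t` with magnitude spectrum :math:`|S_t(k)|` at bin center
    frequency :math:`f_k`, the centroid is :math:`\frac{\sum_k f_k |S_t(k)|}{\sum_k |S_t(k)|}`.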

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.SpectralCentroid(sample_rate)
        >>> spectral_centroid = transform(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(SpectralCentroid, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """

        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
                                   self.win_length)


class PitchShift(torch.nn.Module):
    r"""Shift the pitch of a waveform by ``n_steps`` steps.

    Args:
        sample_rate (int): Sample rate of the input waveform.
        n_steps (int): The steps to shift the waveform.
        bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
        win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
        hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4``
            is used (Default: ``None``).
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.PitchShift(sample_rate, 4)
        >>> waveform_shift = transform(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_steps', 'bins_per_octave', 'n_fft', 'win_length', 'hop_length']

    def __init__(self,
                 sample_rate: int,
                 n_steps: int,
                 bins_per_octave: int = 12,
                 n_fft: int = 512,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(PitchShift, self).__init__()
        self.n_steps = n_steps
        self.bins_per_octave = bins_per_octave
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 4
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: The pitch-shifted audio of shape `(..., time)`.
        """

        return F.pitch_shift(waveform, self.sample_rate, self.n_steps, self.bins_per_octave, self.n_fft,
                             self.win_length, self.hop_length, self.window)