# -*- coding: utf-8 -*-

import math
from typing import Callable, Optional

import torch
from torch import Tensor
from torchaudio import functional as F

__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'SlidingWindowCmn',
    'Vad',
    'SpectralCentroid',
    'Vol',
    'ComputeDeltas',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        return_complex (bool, optional):
            When ``return_complex = True``, this transform returns the resulting Tensor in
            complex dtype, otherwise it returns the resulting Tensor in real dtype with an extra
            dimension for the real and imaginary parts (see ``torch.view_as_real``).
            When ``power`` is given, this must be ``False``, as the resulting
            Tensor represents real-valued power.
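
    Example
        >>> # Illustrative usage; 'test.wav' is a placeholder audio file.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, time)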
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 return_complex: bool = False) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        self.return_complex = return_complex

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(
            waveform,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided,
            self.return_complex,
        )


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    Implementation ported from ``librosa`` [1]_, [2]_, [3]_.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for the phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)

    References:
        .. [1]
           | McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg,
             and Oriol Nieto.
           | "librosa: Audio and music signal analysis in python."
           | In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

        .. [2]
           | Perraudin, N., Balazs, P., & Søndergaard, P. L.
           | "A fast Griffin-Lim algorithm,"
           | IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
           | Oct. 2013.

        .. [3]
           | D. W. Griffin and J. S. Lim,
           | "Signal estimation from modified short-time Fourier transform,"
           | IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
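
    Example
        >>> # Illustrative round trip; the parameters must match those of the spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400, power=2)(waveform)
        >>> waveform_approx = transforms.GriffinLim(n_fft=400, power=2)(specgram)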
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum={} >= 1 can be unstable'.format(momentum)
        assert momentum >= 0, 'momentum={} < 0'.format(momentum)

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor):
                A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
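
    Example
        >>> # Illustrative usage; converts a power spectrogram to decibels.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> specgram_db = transforms.AmplitudeToDB(stype='power', top_db=80)(specgram)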
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control on which device the filter bank (`fb`) resides (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
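
    Example
        >>> # Illustrative usage; n_stft should equal n_fft // 2 + 1 of the input spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)
        >>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=sample_rate, n_stft=201)(specgram)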
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        assert f_min <= self.f_max, 'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
            self.mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """

        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
                                        self.n_mels, self.sample_rate, self.norm,
                                        self.mel_scale)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the Euclidean norm between the input mel-spectrogram and the product of
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
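
    Example
        >>> # Illustrative usage; the reconstruction is approximate and runs an iterative solver.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)
        >>> specgram = transforms.InverseMelScale(n_stft=201)(mel_specgram)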
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm,
                                mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take the sum over mel-frequency then average over the other dimensions
            # so that the loss threshold is applied per unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=self.power,
                                       normalized=self.normalized, wkwargs=wkwargs,
                                       center=center, pad_mode=pad_mode, onesided=onesided)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            norm,
            mel_scale
        )

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
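
    Example
        >>> # Illustrative usage; yields n_mfcc coefficients per frame.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mfcc = transforms.MFCC(sample_rate=sample_rate, n_mfcc=40)(waveform)  # (channel, n_mfcc, time)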
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (..., channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
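
    Example
        >>> # Illustrative usage; the input is assumed to be scaled to [-1., 1.].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> x_mu = transforms.MuLawEncoding(quantization_channels=256)(waveform)  # values in [0, 255]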
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
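
    Example
        >>> # Illustrative round trip from an encoded signal back to [-1., 1.].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> x_mu = transforms.MuLawEncoding()(waveform)
        >>> waveform_approx = transforms.MuLawDecoding()(x_mu)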
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The decoded signal.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
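
    Example
        >>> # Illustrative usage; downsamples from 16 kHz to 8 kHz.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> resampled = transforms.Resample(orig_freq=16000, new_freq=8000)(waveform)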
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':
            return F.resample(waveform, self.orig_freq, self.new_freq)

        raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
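
    Example
        >>> # Illustrative usage; power=None keeps the complex pair as (..., freq, time, 2).
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spec = transforms.Spectrogram(power=None)(waveform)
        >>> magnitude = transforms.ComplexNorm(power=1.0)(spec)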
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor of shape `(..., complex=2)`.

        Returns:
            Tensor: Norm of the input tensor, of shape `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int, optional): The window length used for computing delta. (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
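
    Example
        >>> # Illustrative usage; computes first-order deltas of a spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> deltas = transforms.ComputeDeltas(win_length=5)(specgram)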
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``n_freq - 1``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
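
    Example
        >>> # Illustrative usage; a complex spectrogram sped up by a factor of 1.2.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spec = transforms.Spectrogram(power=None, return_complex=True)(waveform)
        >>> stretched = transforms.TimeStretch(n_freq=201, fixed_rate=1.2)(spec)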
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor):
                Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor:
                Stretched spectrogram. The resulting tensor is of the same dtype as the input
                spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
        """
        if overriding_rate is None:
            if self.fixed_rate is None:
                raise ValueError(
                    "If no fixed_rate is specified, must pass a valid rate to the forward method.")
            rate = self.fixed_rate
        else:
            rate = overriding_rate
        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to a waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """

    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
        return self._fade_in(waveform_length).to(device) * \
            self._fade_out(waveform_length).to(device) * waveform

    def _fade_in(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions (..., freq, time).
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
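
    Example
        >>> # Illustrative SpecAugment-style usage; masks up to 30 frequency bins.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> masked = transforms.FrequencyMasking(freq_mask_param=30)(specgram)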
    """

    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
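
    Example
        >>> # Illustrative usage; masks up to 50 consecutive time frames.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> masked = transforms.TimeMasking(time_mask_param=50)(specgram)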
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class Vol(torch.nn.Module):
    r"""Adjust the volume of a waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
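
    Example
        >>> # Illustrative usage; doubles the amplitude and clamps to [-1, 1].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> louder = transforms.Vol(gain=2., gain_type='amplitude')(waveform)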
    """

    def __init__(self, gain: float, gain_type: str = 'amplitude'):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
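
    Example
        >>> # A minimal sketch following this transform's documented input shape.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> cmn = transforms.SlidingWindowCmn(cmn_window=600, center=True)(waveform)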
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform


class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    References:
        http://sox.sourceforge.net/sox.html
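
    Example
        >>> # Illustrative usage; trims leading silence. Reverse the audio to trim the tail.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> trimmed = transforms.Vad(sample_rate=sample_rate, trigger_level=7.0)(waveform)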
    """

    def __init__(self,
                 sample_rate: int,
                 trigger_level: float = 7.0,
                 trigger_time: float = 0.25,
                 search_time: float = 1.0,
                 allowed_gap: float = 0.25,
                 pre_trigger_time: float = 0.0,
                 boot_time: float = .35,
                 noise_up_time: float = .1,
                 noise_down_time: float = .01,
                 noise_reduction_amount: float = 1.35,
                 measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None,
                 measure_smooth_time: float = .4,
                 hp_filter_freq: float = 50.,
                 lp_filter_freq: float = 6000.,
                 hp_lifter_freq: float = 150.,
                 lp_lifter_freq: float = 2000.) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )


class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spectral_centroid = transforms.SpectralCentroid(sample_rate)(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(SpectralCentroid, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """

        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
                                   self.win_length)