# -*- coding: utf-8 -*-

import math
import warnings
from typing import Callable, Optional

import torch
from torch import Tensor
from torchaudio import functional as F

__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'SlidingWindowCmn',
    'Vad',
    'SpectralCentroid',
    'Vol',
    'ComputeDeltas',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from a audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        return_complex (bool, optional):
            When ``return_complex = True``, this function returns the resulting Tensor in
            complex dtype, otherwise it returns the resulting Tensor in real dtype with an extra
            dimension for the real and imaginary parts (see ``torch.view_as_real``).
            When ``power`` is provided, the value must be ``False``, as the resulting
            Tensor represents real-valued power.
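
    Example
        >>> # Illustrative usage sketch; assumes 'test.wav' exists, as in the examples below.
        >>> waveform, sample_rate = torchaudio.load('test.wav')
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, time)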
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 return_complex: bool = False) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        self.return_complex = return_complex

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(
            waveform,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided,
            self.return_complex,
        )


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    Implementation ported from ``librosa`` [1]_, [2]_, [3]_.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)

    References:
        .. [1]
           | McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg,
             and Oriol Nieto.
           | "librosa: Audio and music signal analysis in python."
           | In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

        .. [2]
           | Perraudin, N., Balazs, P., & Søndergaard, P. L.
           | "A fast Griffin-Lim algorithm,"
           | IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
           | Oct. 2013.

        .. [3]
           | D. W. Griffin and J. S. Lim,
           | "Signal estimation from modified short-time Fourier transform,"
           | IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
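
    Example
        >>> # Illustrative round-trip sketch; assumes ``waveform`` loaded via torchaudio.load.
        >>> specgram = transforms.Spectrogram(n_fft=400, power=2)(waveform)
        >>> reconstructed = transforms.GriffinLim(n_fft=400, power=2)(specgram)  # (..., time)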
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
        assert momentum >= 0, 'momentum={} < 0'.format(momentum)

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor):
                A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
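
    Example
        >>> # Illustrative sketch; assumes ``specgram`` is a power spectrogram from transforms.Spectrogram.
        >>> specgram_db = transforms.AmplitudeToDB(stype='power', top_db=80.)(specgram)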
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control the device on which the filter bank (``fb``) resides (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
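
    Example
        >>> # Illustrative sketch; ``n_stft`` must equal ``n_fft // 2 + 1`` of the input spectrogram.
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)
        >>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=16000, n_stft=201)(specgram)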
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm,
            self.mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """

        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
                                        self.n_mels, self.sample_rate, self.norm,
                                        self.mel_scale)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the Euclidean norm between the input mel-spectrogram and the product
    of the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
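
    Example
        >>> # Illustrative sketch; assumes ``mel_specgram`` from transforms.MelScale with ``n_fft=400``.
        >>> inverse = transforms.InverseMelScale(n_stft=201, n_mels=128)
        >>> specgram = inverse(mel_specgram)  # approximate linear-scale spectrogram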
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm,
                                mel_scale)
        self.register_buffer('fb', fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that loss threshold is applied par unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 norm: Optional[str] = None,
                 mel_scale: str = "htk") -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=self.power,
                                       normalized=self.normalized, wkwargs=wkwargs,
                                       center=center, pad_mode=pad_mode, onesided=onesided)
        self.mel_scale = MelScale(
            self.n_mels,
            self.sample_rate,
            self.f_min,
            self.f_max,
            self.n_fft // 2 + 1,
            norm,
            mel_scale
        )

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
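
    Example
        >>> # Illustrative sketch; assumes ``waveform`` loaded via torchaudio.load.
        >>> mfcc = transforms.MFCC(sample_rate=16000, n_mfcc=40)(waveform)  # (channel, n_mfcc, time)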
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (..., channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
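
    Example
        >>> # Illustrative sketch; assumes ``waveform`` is already scaled to [-1, 1].
        >>> encoded = transforms.MuLawEncoding(quantization_channels=256)(waveform)  # values in [0, 255]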
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
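
    Example
        >>> # Illustrative round-trip sketch; assumes ``encoded`` comes from transforms.MuLawEncoding.
        >>> decoded = transforms.MuLawDecoding(quantization_channels=256)(encoded)  # values in [-1, 1]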
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The decoded signal.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
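
    Example
        >>> # Illustrative sketch; assumes a 16 kHz ``waveform`` loaded via torchaudio.load.
        >>> resampled = transforms.Resample(orig_freq=16000, new_freq=8000)(waveform)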
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':
            return F.resample(waveform, self.orig_freq, self.new_freq)

        raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        warnings.warn(
            'torchaudio.transforms.ComplexNorm has been deprecated '
            'and will be removed from a future release. '
            'Please convert the input Tensor to complex type with `torch.view_as_complex` then '
            'use `torch.abs` and `torch.angle`. '
            'Please refer to https://github.com/pytorch/audio/issues/1337 '
            "for more details about torchaudio's plan to migrate to native complex type."
        )
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
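
    Example
        >>> # Illustrative sketch; assumes ``specgram`` is a (..., freq, time) spectrogram tensor.
        >>> deltas = transforms.ComputeDeltas(win_length=5)(specgram)  # same shape as the input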
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
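
    Example
        >>> # Illustrative sketch; assumes ``waveform`` from torchaudio.load and n_freq = n_fft // 2 + 1.
        >>> spec = transforms.Spectrogram(n_fft=400, power=None, return_complex=True)(waveform)
        >>> stretched = transforms.TimeStretch(n_freq=201, fixed_rate=1.2)(spec)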
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor):
                Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor:
                Stretched spectrogram. The resulting tensor is of the same dtype as the input
                spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
        """
        if overriding_rate is None:
            if self.fixed_rate is None:
                raise ValueError(
                    "If no fixed_rate is specified, must pass a valid rate to the forward method.")
            rate = self.fixed_rate
        else:
            rate = overriding_rate
        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to an waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """

    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
        return self._fade_in(waveform_length).to(device) * \
            self._fade_out(waveform_length).to(device) * waveform

    def _fade_in(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions (..., freq, time).
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
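
    Example
        >>> # Illustrative sketch; assumes ``specgram`` is a (..., freq, time) spectrogram tensor.
        >>> masking = transforms.FrequencyMasking(freq_mask_param=30)
        >>> masked = masking(specgram)  # masks up to 30 consecutive frequency bins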
    """

    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
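
    Example
        >>> # Illustrative sketch; assumes ``specgram`` is a (..., freq, time) spectrogram tensor.
        >>> masking = transforms.TimeMasking(time_mask_param=40)
        >>> masked = masking(specgram)  # masks up to 40 consecutive time frames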
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class Vol(torch.nn.Module):
    r"""Add a volume to an waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
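
    Example
        >>> # Illustrative sketch; assumes ``waveform`` loaded via torchaudio.load.
        >>> quieter = transforms.Vol(gain=0.5, gain_type='amplitude')(waveform)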
    """

    def __init__(self, gain: float, gain_type: str = 'amplitude'):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
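
    Example
        >>> # Illustrative sketch; assumes an input tensor as described in the forward method below.
        >>> cmn = transforms.SlidingWindowCmn(cmn_window=600, center=True)(waveform)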
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform


class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm's
            processing/measurements. (Default: 20.0)
        measure_duration (float or None, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    References:
        http://sox.sourceforge.net/sox.html
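
    Example
        >>> # Illustrative sketch; 'speech.wav' is a hypothetical speech recording.
        >>> waveform, sample_rate = torchaudio.load('speech.wav')
        >>> trimmed = transforms.Vad(sample_rate=sample_rate)(waveform)  # trims leading silence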
    """

    def __init__(self,
                 sample_rate: int,
                 trigger_level: float = 7.0,
                 trigger_time: float = 0.25,
                 search_time: float = 1.0,
                 allowed_gap: float = 0.25,
                 pre_trigger_time: float = 0.0,
                 boot_time: float = .35,
                 noise_up_time: float = .1,
                 noise_down_time: float = .01,
                 noise_reduction_amount: float = 1.35,
                 measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None,
                 measure_smooth_time: float = .4,
                 hp_filter_freq: float = 50.,
                 lp_filter_freq: float = 6000.,
                 hp_lifter_freq: float = 150.,
                 lp_lifter_freq: float = 2000.) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )


class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spectral_centroid = transforms.SpectralCentroid(sample_rate)(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(SpectralCentroid, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """

        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
                                   self.win_length)