# -*- coding: utf-8 -*-

import math
from typing import Callable, Optional

import torch
from torch import Tensor

from torchaudio import functional as F
from torchaudio.compliance import kaldi


__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'TimeStretch',
    'Fade',
    'FrequencyMasking',
    'TimeMasking',
    'SlidingWindowCmn',
    'Vad',
    'SpectralCentroid',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. (Default: ``True``)
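
    Example
        >>> # A minimal usage sketch; 'test.wav' is a placeholder path.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, time)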
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` with ``n_fft`` the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(
            waveform,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided
        )


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    Implementation ported from ``librosa`` [1]_, [2]_, [3]_.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for the phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)

    References:
        .. [1]
           | McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg,
             and Oriol Nieto.
           | "librosa: Audio and music signal analysis in python."
           | In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

        .. [2]
           | Perraudin, N., Balazs, P., & Søndergaard, P. L.
           | "A fast Griffin-Lim algorithm,"
           | IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
           | Oct. 2013.

        .. [3]
           | D. W. Griffin and J. S. Lim,
           | "Signal estimation from modified short-time Fourier transform,"
           | IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
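
    Example
        >>> # A minimal sketch; recovers a waveform from a power spectrogram (power=2 in both transforms).
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400, power=2)(waveform)
        >>> reconstructed = transforms.GriffinLim(n_fft=400, power=2)(specgram)  # (channel, time)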
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self,
                 n_fft: int = 400,
                 n_iter: int = 32,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: float = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 momentum: float = 0.99,
                 length: Optional[int] = None,
                 rand_init: bool = True) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
        assert momentum >= 0, 'momentum={} < 0'.format(momentum)

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor):
                A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
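
    Example
        >>> # A minimal sketch; converts a power spectrogram to decibels.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> specgram_db = transforms.AmplitudeToDB(stype='power', top_db=80.)(specgram)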
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control the device on which the filter bank (``fb``) resides
    (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
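
    Example
        >>> # A minimal sketch; maps a 201-bin STFT (n_fft=400) onto 128 mel bins.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, 201, time)
        >>> mel_specgram = transforms.MelScale(sample_rate=sample_rate, n_stft=201)(specgram)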
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 norm: Optional[str] = None) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm

        assert f_min <= self.f_max, 'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm)
        self.register_buffer('fb', fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
                                        self.n_mels, self.sample_rate, self.norm)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidean norm between the input mel-spectrogram and the product between
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
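
    Example
        >>> # A minimal sketch; recovering a 201-bin STFT from a mel spectrogram is iterative and slow.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, 128, time)
        >>> specgram = transforms.InverseMelScale(n_stft=201)(mel_specgram)  # (channel, 201, time)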
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self,
                 n_stft: int,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 max_iter: int = 100000,
                 tolerance_loss: float = 1e-5,
                 tolerance_change: float = 1e-8,
                 sgdargs: Optional[dict] = None,
                 norm: Optional[str] = None) -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: {} <= f_max: {}'.format(f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm)
        self.register_buffer('fb', fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that loss threshold is applied per unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. (Default: ``True``)
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 pad: int = 0,
                 n_mels: int = 128,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 power: Optional[float] = 2.,
                 normalized: bool = False,
                 wkwargs: Optional[dict] = None,
                 center: bool = True,
                 pad_mode: str = "reflect",
                 onesided: bool = True,
                 norm: Optional[str] = None) -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=self.power,
                                       normalized=self.normalized, wkwargs=wkwargs,
                                       center=center, pad_mode=pad_mode, onesided=onesided)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
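
    Example
        >>> # A minimal sketch; returns 40 MFCCs per frame by default.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mfcc = transforms.MFCC(sample_rate=sample_rate)(waveform)  # (channel, n_mfcc, time)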
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self,
                 sample_rate: int = 16000,
                 n_mfcc: int = 40,
                 dct_type: int = 2,
                 norm: str = 'ortho',
                 log_mels: bool = False,
                 melkwargs: Optional[dict] = None) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (..., channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
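
    Example
        >>> # A minimal sketch; assumes the waveform is already scaled to [-1., 1.].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> x_mu = transforms.MuLawEncoding(quantization_channels=256)(waveform)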
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            x_mu (Tensor): An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)
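
    Example
        >>> # A minimal sketch; x_mu is a mu-law encoded signal as produced by MuLawEncoding.
        >>> x = transforms.MuLawDecoding(quantization_channels=256)(x_mu)  # values in [-1., 1.]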
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The decoded signal.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
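
    Example
        >>> # A minimal sketch; downsamples from the file's native rate to 8 kHz.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> resampled = transforms.Resample(orig_freq=sample_rate, new_freq=8000)(waveform)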
    """

    def __init__(self,
                 orig_freq: int = 16000,
                 new_freq: int = 16000,
                 resampling_method: str = 'sinc_interpolation') -> None:
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.resampling_method == 'sinc_interpolation':
            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: {}'.format(self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input.

    Args:
        power (float, optional): Power of the norm. (Default: ``1.0``)
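
    Example
        >>> # A minimal sketch; a complex spectrogram is obtained with power=None.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> complex_specgram = transforms.Spectrogram(power=None)(waveform)  # (channel, freq, time, 2)
        >>> magnitude = transforms.ComplexNorm(power=1.)(complex_specgram)  # (channel, freq, time)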
    """
    __constants__ = ['power']

    def __init__(self, power: float = 1.0) -> None:
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor: Tensor) -> Tensor:
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`.
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
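
    Example
        >>> # A minimal sketch; computes deltas over the time axis of a spectrogram.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> deltas = transforms.ComputeDeltas(win_length=5)(specgram)  # (channel, freq, time)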
    """
    __constants__ = ['win_length']

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows.
            (Default: ``n_fft // 2``, with ``n_fft = (n_freq - 1) * 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
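
    Example
        >>> # A minimal sketch; stretches a complex spectrogram by 20% without changing pitch.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> complex_specgram = transforms.Spectrogram(power=None)(waveform)  # (channel, freq, time, 2)
        >>> stretched = transforms.TimeStretch(n_freq=201)(complex_specgram, 1.2)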
    """
    __constants__ = ['fixed_rate']

    def __init__(self,
                 hop_length: Optional[int] = None,
                 n_freq: int = 201,
                 fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to an waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
            "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
    """
    def __init__(self,
                 fade_in_len: int = 0,
                 fade_out_len: int = 0,
                 fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
        return self._fade_in(waveform_length).to(device) * \
            self._fade_out(waveform_length).to(device) * waveform

    def _fade_in(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension (..., freq, time).
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions (..., freq, time).
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
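
    Example
        >>> # A minimal sketch; masks up to 30 consecutive frequency bins at a random location.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> masked = transforms.FrequencyMasking(freq_mask_param=30)(specgram)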
    """
    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
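
    Example
        >>> # A minimal sketch; masks up to 80 consecutive time frames at a random location.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> masked = transforms.TimeMasking(time_mask_param=80)(specgram)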
    """
    def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)


class Vol(torch.nn.Module):
    r"""Add a volume to an waveform.

    Args:
        gain (float): Interpreted according to the given gain_type:
Vincent QB's avatar
Vincent QB committed
892
893
894
895
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
Tomás Osório's avatar
Tomás Osório committed
896
897
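
    Example
        >>> # A minimal sketch; halves the amplitude and clamps the result to [-1., 1.].
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> quieter = transforms.Vol(gain=0.5, gain_type='amplitude')(waveform)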
    """

    def __init__(self, gain: float, gain_type: str = 'amplitude'):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ['amplitude', 'power'] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation. (Default: ``600``)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center == true. (Default: ``100``)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (Default: ``False``)
        norm_vars (bool, optional): If true, normalize variance to one. (Default: ``False``)
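
    Example
        >>> # A minimal sketch; mirrors the (..., time) contract of this module's forward method.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> cmn_waveform = transforms.SlidingWindowCmn(cmn_window=600)(waveform)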
    """

    def __init__(self,
                 cmn_window: int = 600,
                 min_cmn_window: int = 100,
                 center: bool = False,
                 norm_vars: bool = False) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Tensor of audio of dimension (..., time).
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform


class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.
    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float or None, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    References:
        http://sox.sourceforge.net/sox.html
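
    Example
        >>> # A minimal sketch; trims silence from the front (reverse the waveform to trim the end).
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> trimmed = transforms.Vad(sample_rate=sample_rate)(waveform)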
    """

    def __init__(self,
                 sample_rate: int,
                 trigger_level: float = 7.0,
                 trigger_time: float = 0.25,
                 search_time: float = 1.0,
                 allowed_gap: float = 0.25,
                 pre_trigger_time: float = 0.0,
                 boot_time: float = .35,
                 noise_up_time: float = .1,
                 noise_down_time: float = .01,
                 noise_reduction_amount: float = 1.35,
                 measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None,
                 measure_smooth_time: float = .4,
                 hp_filter_freq: float = 50.,
                 lp_filter_freq: float = 6000.,
                 hp_lifter_freq: float = 150.,
                 lp_lifter_freq: float = 2000.) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )


class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spectral_centroid = transforms.SpectralCentroid(sample_rate)(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window_fn: Callable[..., Tensor] = torch.hann_window,
                 wkwargs: Optional[dict] = None) -> None:
        super(SpectralCentroid, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """

        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
                                   self.win_length)