# -*- coding: utf-8 -*-

import math
import warnings
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import UninitializedParameter

from torchaudio import functional as F
from torchaudio.functional.functional import (
    _apply_sinc_resample_kernel,
    _fix_waveform_shape,
    _get_sinc_resample_kernel,
    _stretch_waveform,
)

__all__ = []


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from a audio signal.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float or None, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead. (Default: ``2``)
        normalized (bool or str, optional): Whether to normalize by magnitude after stft. If input is str, choices are
            ``"window"`` and ``"frame_length"``, if specific normalization type is desirable. ``True`` maps to
            ``"window"``. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy (Default: ``True``)
        return_complex (bool, optional):
            Deprecated and not used.

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.Spectrogram(n_fft=800)
        >>> spectrogram = transform(waveform)

    """
    __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]

    def __init__(
        self,
        n_fft: int = 400,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        pad: int = 0,
        window_fn: Callable[..., Tensor] = torch.hann_window,
        power: Optional[float] = 2.0,
        normalized: Union[bool, str] = False,
        wkwargs: Optional[dict] = None,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
        return_complex: Optional[bool] = None,
    ) -> None:
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer("window", window)
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided
        if return_complex is not None:
            warnings.warn(
                "`return_complex` argument is now deprecated and is not effective."
                "`torchaudio.transforms.Spectrogram(power=None)` always returns a tensor with "
                "complex dtype. Please remove the argument in the function call."
            )

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        return F.spectrogram(
            waveform,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided,
        )


class InverseSpectrogram(torch.nn.Module):
    r"""Create an inverse spectrogram to recover an audio signal from a spectrogram.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        normalized (bool or str, optional): Whether the stft output was normalized by magnitude. If input is str,
            choices are ``"window"`` and ``"frame_length"``, dependent on normalization mode. ``True`` maps to
            ``"window"``. (Default: ``False``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether the signal in spectrogram was padded on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether spectrogram was used to return half of results to
            avoid redundancy (Default: ``True``)

    Example
        >>> batch, freq, time = 2, 257, 100
        >>> length = 25344
        >>> spectrogram = torch.randn(batch, freq, time, dtype=torch.cdouble)
        >>> transform = transforms.InverseSpectrogram(n_fft=512)
        >>> waveform = transform(spectrogram, length)
    """
    __constants__ = ["n_fft", "win_length", "hop_length", "pad", "power", "normalized"]

    def __init__(
        self,
        n_fft: int = 400,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        pad: int = 0,
        window_fn: Callable[..., Tensor] = torch.hann_window,
        normalized: Union[bool, str] = False,
        wkwargs: Optional[dict] = None,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
    ) -> None:
        super(InverseSpectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer("window", window)
        self.pad = pad
        self.normalized = normalized
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided

    def forward(self, spectrogram: Tensor, length: Optional[int] = None) -> Tensor:
        r"""
        Args:
            spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
            length (int or None, optional): The output length of the waveform.

        Returns:
            Tensor: Dimension (..., time), Least squares estimation of the original signal.
        """
        return F.inverse_spectrogram(
            spectrogram,
            length,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.normalized,
            self.center,
            self.pad_mode,
            self.onesided,
        )


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Implementation ported from
    *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`]
    and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)

    Example
        >>> batch, freq, time = 2, 257, 100
        >>> spectrogram = torch.randn(batch, freq, time)
        >>> transform = transforms.GriffinLim(n_fft=512)
        >>> waveform = transform(spectrogram)
    """
    __constants__ = ["n_fft", "n_iter", "win_length", "hop_length", "power", "length", "momentum", "rand_init"]

    def __init__(
        self,
        n_fft: int = 400,
        n_iter: int = 32,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        window_fn: Callable[..., Tensor] = torch.hann_window,
        power: float = 2.0,
        wkwargs: Optional[dict] = None,
        momentum: float = 0.99,
        length: Optional[int] = None,
        rand_init: bool = True,
    ) -> None:
        super(GriffinLim, self).__init__()

        assert momentum < 1, "momentum={} >= 1 can be unstable".format(momentum)
        assert momentum >= 0, "momentum={} < 0".format(momentum)

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer("window", window)
        self.length = length
        self.power = power
        self.momentum = momentum
        self.rand_init = rand_init

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor):
                A magnitude-only STFT spectrogram of dimension (..., freq, frames)
                where freq is ``n_fft // 2 + 1``.

        Returns:
            Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
        """
        return F.griffinlim(
            specgram,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.power,
            self.n_iter,
            self.momentum,
            self.length,
            self.rand_init,
        )


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str, optional): scale of input tensor (``'power'`` or ``'magnitude'``). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float or None, optional): minimum negative cut-off in decibels.  A reasonable
            number is 80. (Default: ``None``)
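
    Example (illustrative sketch; assumes a local ``test.wav`` file)
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.Spectrogram()(waveform)  # power spectrogram
        >>> transform = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80)
        >>> specgram_db = transform(specgram)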
    """
    __constants__ = ["multiplier", "amin", "ref_value", "db_multiplier"]

    def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError("top_db must be positive value")
        self.top_db = top_db
        self.multiplier = 10.0 if stype == "power" else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: Tensor) -> Tensor:
        r"""Numerically stable implementation from Librosa.

        https://librosa.org/doc/latest/generated/librosa.amplitude_to_db.html

        Args:
            x (Tensor): Input tensor before being converted to decibel scale.

        Returns:
            Tensor: Output tensor in decibel scale.
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT with triangular filter banks.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
        norm (str or None, optional): If ``'slaney'``, divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
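
    Example (illustrative sketch; assumes a local ``test.wav`` file)
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)  # (channel, 201, time)
        >>> transform = torchaudio.transforms.MelScale(n_mels=128, sample_rate=sample_rate, n_stft=201)
        >>> mel_specgram = transform(specgram)  # (channel, 128, time)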

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

    def __init__(
        self,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: int = 201,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.norm = norm
        self.mel_scale = mel_scale

        assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)
        fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.norm, self.mel_scale)
        self.register_buffer("fb", fb)

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """

        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Estimate a STFT in normal frequency domain from mel frequency domain.

    .. devices:: CPU CUDA

    It minimizes the euclidean norm between the input mel-spectrogram and the product between
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
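
    Example (illustrative sketch; the input here is a random mel spectrogram and the shapes are arbitrary)
        >>> mel_specgram = torch.rand(1, 128, 100)  # (channel, n_mels, time)
        >>> transform = torchaudio.transforms.InverseMelScale(n_stft=201, n_mels=128)
        >>> specgram = transform(mel_specgram)  # (channel, 201, time)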
    """
    __constants__ = [
        "n_stft",
        "n_mels",
        "sample_rate",
        "f_min",
        "f_max",
        "max_iter",
        "tolerance_loss",
        "tolerance_change",
        "sgdargs",
    ]

    def __init__(
        self,
        n_stft: int,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        max_iter: int = 100000,
        tolerance_loss: float = 1e-5,
        tolerance_change: float = 1e-8,
        sgdargs: Optional[dict] = None,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {"lr": 0.1, "momentum": 0.9}

        assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(f_min, self.f_max)

        fb = F.melscale_fbanks(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, norm, mel_scale)
        self.register_buffer("fb", fb)

    def forward(self, melspec: Tensor) -> Tensor:
        r"""
        Args:
            melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(
            melspec.size()[0], time, freq, requires_grad=True, dtype=melspec.dtype, device=melspec.device
        )

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float("inf")
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(axis=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that the loss threshold is applied per unit time frame
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    This is a composition of :py:func:`torchaudio.transforms.Spectrogram`
    and :py:func:`torchaudio.transforms.MelScale`.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. (Default: ``True``)
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.MelSpectrogram(sample_rate)
        >>> mel_specgram = transform(waveform)  # (channel, n_mels, time)

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad", "n_mels", "f_min"]

    def __init__(
        self,
        sample_rate: int = 16000,
        n_fft: int = 400,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        pad: int = 0,
        n_mels: int = 128,
        window_fn: Callable[..., Tensor] = torch.hann_window,
        power: float = 2.0,
        normalized: bool = False,
        wkwargs: Optional[dict] = None,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
    ) -> None:
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.power = power
        self.normalized = normalized
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            pad=self.pad,
            window_fn=window_fn,
            power=self.power,
            normalized=self.normalized,
            wkwargs=wkwargs,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )
        self.mel_scale = MelScale(
            self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1, norm, mel_scale
        )

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
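
    Example (illustrative sketch; assumes a local ``test.wav`` file, and the ``melkwargs`` values are arbitrary)
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.MFCC(
        >>>     sample_rate=sample_rate,
        >>>     n_mfcc=13,
        >>>     melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 128},
        >>> )
        >>> mfcc = transform(waveform)  # (channel, n_mfcc, time)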

    See also:
        :py:func:`torchaudio.functional.melscale_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ["sample_rate", "n_mfcc", "dct_type", "top_db", "log_mels"]

    def __init__(
        self,
        sample_rate: int = 16000,
        n_mfcc: int = 40,
        dct_type: int = 2,
        norm: str = "ortho",
        log_mels: bool = False,
        melkwargs: Optional[dict] = None,
    ) -> None:
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError("DCT type not supported: {}".format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)

        melkwargs = melkwargs or {}
        self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError("Cannot select more MFCC coefficients than # mel bins")
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer("dct_mat", dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (..., time, n_mels) dot (n_mels, n_mfcc) -> (..., n_mfcc, time)
        mfcc = torch.matmul(mel_specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
        return mfcc


class LFCC(torch.nn.Module):
    r"""Create the linear-frequency cepstrum coefficients from an audio signal.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    By default, this calculates the LFCC on the DB-scaled linear filtered spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
        n_filter (int, optional): Number of linear filters to apply. (Default: ``128``)
        n_lfcc (int, optional): Number of lfc coefficients to retain. (Default: ``40``)
        f_min (float, optional): Minimum frequency. (Default: ``0.``)
        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_lf (bool, optional): whether to use log-lf spectrograms instead of db-scaled. (Default: ``False``)
        speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``)
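
    Example (illustrative sketch; assumes a local ``test.wav`` file, and the ``speckwargs`` values are arbitrary)
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.LFCC(
        >>>     sample_rate=sample_rate,
        >>>     n_lfcc=20,
        >>>     speckwargs={"n_fft": 400, "hop_length": 160},
        >>> )
        >>> lfcc = transform(waveform)  # (channel, n_lfcc, time)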


    See also:
        :py:func:`torchaudio.functional.linear_fbanks` - The function used to
        generate the filter banks.
    """
    __constants__ = ["sample_rate", "n_filter", "n_lfcc", "dct_type", "top_db", "log_lf"]

    def __init__(
        self,
        sample_rate: int = 16000,
        n_filter: int = 128,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_lfcc: int = 40,
        dct_type: int = 2,
        norm: str = "ortho",
        log_lf: bool = False,
        speckwargs: Optional[dict] = None,
    ) -> None:
        super(LFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError("DCT type not supported: {}".format(dct_type))
        self.sample_rate = sample_rate
        self.f_min = f_min
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.n_filter = n_filter
        self.n_lfcc = n_lfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB("power", self.top_db)

        speckwargs = speckwargs or {}
        self.Spectrogram = Spectrogram(**speckwargs)

        if self.n_lfcc > self.Spectrogram.n_fft:
            raise ValueError("Cannot select more LFCC coefficients than # fft bins")

        filter_mat = F.linear_fbanks(
            n_freqs=self.Spectrogram.n_fft // 2 + 1,
            f_min=self.f_min,
            f_max=self.f_max,
            n_filter=self.n_filter,
            sample_rate=self.sample_rate,
        )
        self.register_buffer("filter_mat", filter_mat)

        dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
        self.register_buffer("dct_mat", dct_mat)
        self.log_lf = log_lf

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Linear Frequency Cepstral Coefficients of size (..., ``n_lfcc``, time).
        """
        specgram = self.Spectrogram(waveform)

        # (..., time, freq) dot (freq, n_filter) -> (..., n_filter, time)
        specgram = torch.matmul(specgram.transpose(-1, -2), self.filter_mat).transpose(-1, -2)

        if self.log_lf:
            log_offset = 1e-6
            specgram = torch.log(specgram + log_offset)
        else:
            specgram = self.amplitude_to_DB(specgram)

        # (..., time, n_filter) dot (n_filter, n_lfcc) -> (..., n_lfcc, time)
        lfcc = torch.matmul(specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
        return lfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)

    Example
       >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
       >>> transform = torchaudio.transforms.MuLawEncoding(quantization_channels=512)
       >>> mulawtrans = transform(waveform)

    """
    __constants__ = ["quantization_channels"]

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): A signal to be encoded.

        Returns:
            Tensor: An encoded signal.
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and ``quantization_channels - 1``
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int, optional): Number of channels. (Default: ``256``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
        >>> mulawtrans = transform(waveform)
    """
    __constants__ = ["quantization_channels"]

    def __init__(self, quantization_channels: int = 256) -> None:
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu: Tensor) -> Tensor:
        r"""
        Args:
            x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

        Returns:
            Tensor: The signal decoded.
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can be given.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        If resampling on waveforms of higher precision than float32, there may be a small loss of precision
        because the kernel is cached once as float32. If high precision resampling is important for your application,
        the functional form will retain higher precision, but run slower because it does not cache the kernel.
        Alternatively, you could rewrite a transform that caches a higher precision kernel.

    Args:
        orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
        new_freq (int, optional): The desired frequency. (Default: ``16000``)
        resampling_method (str, optional): The resampling method to use.
            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. (Default: ``6``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
        beta (float or None, optional): The shape parameter used for kaiser window.
        dtype (torch.dtype, optional):
            Determines the precision at which the resampling kernel is pre-computed and cached. If not provided,
            the kernel is computed with ``torch.float64`` then cached as ``torch.float32``.
            If you need higher precision, provide ``torch.float64``, and the pre-computed kernel is computed and
            cached as ``torch.float64``. If you use resample with lower precision, then instead of providing this
            argument, please use ``Resample.to(dtype)``, so that the kernel generation is still
            carried out on ``torch.float64``.

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.Resample(sample_rate, sample_rate/10)
        >>> waveform = transform(waveform)
    """

    def __init__(
        self,
        orig_freq: int = 16000,
        new_freq: int = 16000,
        resampling_method: str = "sinc_interpolation",
        lowpass_filter_width: int = 6,
        rolloff: float = 0.99,
        beta: Optional[float] = None,
        *,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()

        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.gcd = math.gcd(int(self.orig_freq), int(self.new_freq))
        self.resampling_method = resampling_method
        self.lowpass_filter_width = lowpass_filter_width
        self.rolloff = rolloff
        self.beta = beta

        if self.orig_freq != self.new_freq:
            kernel, self.width = _get_sinc_resample_kernel(
                self.orig_freq,
                self.new_freq,
                self.gcd,
                self.lowpass_filter_width,
                self.rolloff,
                self.resampling_method,
                beta,
                dtype=dtype,
            )
            self.register_buffer("kernel", kernel)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Output signal of dimension (..., time).
        """
        if self.orig_freq == self.new_freq:
            return waveform
        return _apply_sinc_resample_kernel(waveform, self.orig_freq, self.new_freq, self.gcd, self.kernel, self.width)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int, optional): The window length used for computing delta. (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
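
    Example (illustrative sketch; assumes a local ``test.wav`` file)
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> specgram = torchaudio.transforms.MelSpectrogram(sample_rate)(waveform)
        >>> transform = torchaudio.transforms.ComputeDeltas(win_length=5)
        >>> deltas = transform(specgram)  # same shape as ``specgram``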
    """
    __constants__ = ["win_length"]

    def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of audio of dimension (..., freq, time).

        Returns:
            Tensor: Tensor of deltas of dimension (..., freq, time).
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Proposed in *SpecAugment* [:footcite:`specaugment`].

    Args:
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float or None, optional): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)

    Example
        >>> spectrogram = torchaudio.transforms.Spectrogram()
        >>> stretch = torchaudio.transforms.TimeStretch()
        >>>
        >>> original = spectrogram(waveform)
        >>> stretched_1_2 = stretch(original, 1.2)
        >>> stretched_0_9 = stretch(original, 0.9)

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
           :width: 600
           :alt: Spectrogram stretched by 1.2

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png
           :width: 600
           :alt: The original spectrogram

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
           :width: 600
           :alt: Spectrogram stretched by 0.9

    """
    __constants__ = ["fixed_rate"]

    def __init__(self, hop_length: Optional[int] = None, n_freq: int = 201, fixed_rate: Optional[float] = None) -> None:
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
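        # Expected phase advance per hop in each frequency bin; passed to ``F.phase_vocoder`` in ``forward``.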
        self.register_buffer("phase_advance", torch.linspace(0, math.pi * hop_length, n_freq)[..., None])

    def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
        r"""
        Args:
            complex_specgrams (Tensor):
                A tensor of dimension `(..., freq, num_frame)` with complex dtype.
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

        Returns:
            Tensor:
                Stretched spectrogram. The resulting tensor is of the same dtype as the input
                spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
        """
        if overriding_rate is None:
            if self.fixed_rate is None:
                raise ValueError("If no fixed_rate is specified, must pass a valid rate to the forward method.")
            rate = self.fixed_rate
        else:
            rate = overriding_rate
        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class Fade(torch.nn.Module):
    r"""Add a fade in and/or fade out to an waveform.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames). (Default: ``0``)
        fade_out_len (int, optional): Length of fade-out (time frames). (Default: ``0``)
        fade_shape (str, optional): Shape of fade. Must be one of: ``"quarter_sine"``,
            ``"half_sine"``, ``"linear"``, ``"logarithmic"``, ``"exponential"``.
            (Default: ``"linear"``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.Fade(fade_in_len=sample_rate, fade_out_len=2 * sample_rate, fade_shape='linear')
        >>> faded_waveform = transform(waveform)
    """

    def __init__(self, fade_in_len: int = 0, fade_out_len: int = 0, fade_shape: str = "linear") -> None:
        super(Fade, self).__init__()
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

        Returns:
            Tensor: Tensor of audio of dimension `(..., time)`.
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
        return self._fade_in(waveform_length, device) * self._fade_out(waveform_length, device) * waveform

    def _fade_in(self, waveform_length: int, device: torch.device) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_in_len, device=device)
        ones = torch.ones(waveform_length - self.fade_in_len, device=device)
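        # ``fade`` is a linear ramp over the first ``fade_in_len`` samples; the branches below reshape it
        # into the requested curve, while the remaining samples stay at full gain (ones).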

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
1114
            fade = torch.log10(0.1 + fade) + 1
Tomás Osório's avatar
Tomás Osório committed
1115
1116
1117
1118
1119
1120
1121
1122
1123

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        return torch.cat((fade, ones)).clamp_(0, 1)

    def _fade_out(self, waveform_length: int, device: torch.device) -> Tensor:
        fade = torch.linspace(0, 1, self.fade_out_len, device=device)
        ones = torch.ones(waveform_length - self.fade_out_len, device=device)

        if self.fade_shape == "linear":
            fade = -fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, -fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        return torch.cat((ones, fade)).clamp_(0, 1)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask.
        axis (int): What dimension the mask is applied on.
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
            This option is applicable only when the input tensor is 4D.
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)
    """
    __constants__ = ["mask_param", "axis", "iid_masks", "p"]

    def __init__(self, mask_param: int, axis: int, iid_masks: bool, p: float = 1.0) -> None:

        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks
        self.p = p

    def forward(self, specgram: Tensor, mask_value: float = 0.0) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of dimension `(..., freq, time)`.
            mask_value (float): Value to assign to the masked columns.

        Returns:
            Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1, p=self.p)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis, p=self.p)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Proposed in *SpecAugment* [:footcite:`specaugment`].

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.

    Example
        >>> spectrogram = torchaudio.transforms.Spectrogram()
        >>> masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=80)
        >>>
        >>> original = spectrogram(waveform)
        >>> masked = masking(original)

        .. image::  https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking1.png
           :alt: The original spectrogram

        .. image::  https://download.pytorch.org/torchaudio/doc-assets/specaugment_freq_masking2.png
           :alt: The spectrogram masked along frequency axis
    """

    def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
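        # axis=1 is the frequency dimension of a `(channel, freq, time)` spectrogram.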
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Proposed in *SpecAugment* [:footcite:`specaugment`].

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool, optional): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
            This option is applicable only when the input tensor is 4D.
        p (float, optional): maximum proportion of time steps that can be masked.
            Must be within range [0.0, 1.0]. (Default: 1.0)

    Example
        >>> spectrogram = torchaudio.transforms.Spectrogram()
        >>> masking = torchaudio.transforms.TimeMasking(time_mask_param=80)
        >>>
        >>> original = spectrogram(waveform)
        >>> masked = masking(original)

        .. image::  https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking1.png
           :alt: The original spectrogram

        .. image::  https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_masking2.png
           :alt: The spectrogram masked along time axis
    """

    def __init__(self, time_mask_param: int, iid_masks: bool = False, p: float = 1.0) -> None:
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
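        # axis=2 is the time dimension of a `(channel, freq, time)` spectrogram.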
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p)


class Vol(torch.nn.Module):
    r"""Adjust volume of waveform.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        gain (float): Interpreted according to the given gain_type:
            If ``gain_type`` = ``amplitude``, ``gain`` is a positive amplitude ratio.
            If ``gain_type`` = ``power``, ``gain`` is a power (voltage squared).
            If ``gain_type`` = ``db``, ``gain`` is in decibels.
        gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.Vol(gain=0.5, gain_type="amplitude")
        >>> quieter_waveform = transform(waveform)
    """

    def __init__(self, gain: float, gain_type: str = "amplitude"):
        super(Vol, self).__init__()
        self.gain = gain
        self.gain_type = gain_type

        if gain_type in ["amplitude", "power"] and gain < 0:
            raise ValueError("If gain_type = amplitude or power, gain must be positive.")

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

        Returns:
            Tensor: Tensor of audio of dimension `(..., time)`.
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain

        if self.gain_type == "db":
            waveform = F.gain(waveform, self.gain)

        if self.gain_type == "power":
            waveform = F.gain(waveform, 10 * math.log10(self.gain))

        return torch.clamp(waveform, -1, 1)


class SlidingWindowCmn(torch.nn.Module):
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
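
    Example
        >>> # Illustrative sketch: any `(..., time, freq)` feature tensor can be normalized this way.
        >>> specgram = torch.rand(1, 1000, 40)
        >>> transform = torchaudio.transforms.SlidingWindowCmn(cmn_window=600, norm_vars=True)
        >>> cmn_specgram = transform(specgram)  # (1, 1000, 40)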
    """

    def __init__(
        self, cmn_window: int = 600, min_cmn_window: int = 100, center: bool = False, norm_vars: bool = False
    ) -> None:
        super().__init__()
        self.cmn_window = cmn_window
        self.min_cmn_window = min_cmn_window
        self.center = center
        self.norm_vars = norm_vars

    def forward(self, specgram: Tensor) -> Tensor:
        r"""
        Args:
            specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)`.

        Returns:
            Tensor: Tensor of spectrogram of dimension `(..., time, freq)`.
        """
        cmn_specgram = F.sliding_window_cmn(specgram, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_specgram


class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float or None, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)
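
    Example
        >>> # Illustrative sketch: trim leading silence from a speech recording.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = torchaudio.transforms.Vad(sample_rate=sample_rate)
        >>> trimmed = transform(waveform)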

    Reference:
        - http://sox.sourceforge.net/sox.html
    """

    def __init__(
        self,
        sample_rate: int,
        trigger_level: float = 7.0,
        trigger_time: float = 0.25,
        search_time: float = 1.0,
        allowed_gap: float = 0.25,
        pre_trigger_time: float = 0.0,
        boot_time: float = 0.35,
        noise_up_time: float = 0.1,
        noise_down_time: float = 0.01,
        noise_reduction_amount: float = 1.35,
        measure_freq: float = 20.0,
        measure_duration: Optional[float] = None,
        measure_smooth_time: float = 0.4,
        hp_filter_freq: float = 50.0,
        lp_filter_freq: float = 6000.0,
        hp_lifter_freq: float = 150.0,
        lp_lifter_freq: float = 2000.0,
    ) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
                Tensor of shape `(channels, time)` is treated as a multi-channel recording
                of the same event and the resulting output will be trimmed to the earliest
                voice activity in any channel.
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )


class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.SpectralCentroid(sample_rate)
        >>> spectral_centroid = transform(waveform)  # (channel, time)
    """
    __constants__ = ["sample_rate", "n_fft", "win_length", "hop_length", "pad"]

    def __init__(
        self,
        sample_rate: int,
        n_fft: int = 400,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        pad: int = 0,
        window_fn: Callable[..., Tensor] = torch.hann_window,
        wkwargs: Optional[dict] = None,
    ) -> None:
        super(SpectralCentroid, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer("window", window)
        self.pad = pad

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

        Returns:
            Tensor: Spectral Centroid of size `(..., time)`.
        """

        return F.spectral_centroid(
            waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length, self.win_length
        )


class PitchShift(LazyModuleMixin, torch.nn.Module):
    r"""Shift the pitch of a waveform by ``n_steps`` steps.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        waveform (Tensor): The input waveform of shape `(..., time)`.
        sample_rate (int): Sample rate of `waveform`.
        n_steps (int): The (fractional) steps to shift `waveform`.
        bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
        win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
        hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4``
            is used (Default: ``None``).
        window_fn (Callable[..., Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
        >>> transform = transforms.PitchShift(sample_rate, 4)
        >>> waveform_shift = transform(waveform)  # (channel, time)
    """
    __constants__ = ["sample_rate", "n_steps", "bins_per_octave", "n_fft", "win_length", "hop_length"]

    kernel: UninitializedParameter
    width: int

    def __init__(
        self,
        sample_rate: int,
        n_steps: int,
        bins_per_octave: int = 12,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        window_fn: Callable[..., Tensor] = torch.hann_window,
        wkwargs: Optional[dict] = None,
    ) -> None:
        super().__init__()
        self.n_steps = n_steps
        self.bins_per_octave = bins_per_octave
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 4
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer("window", window)
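        # The shift is realized by time-stretching the waveform by ``rate`` with a phase
        # vocoder and then resampling from ``sample_rate / rate`` back to ``sample_rate``,
        # which restores the original duration while moving the pitch by ``n_steps`` steps.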
        rate = 2.0 ** (-float(n_steps) / bins_per_octave)
        self.orig_freq = int(sample_rate / rate)
        self.gcd = math.gcd(int(self.orig_freq), int(sample_rate))

        if self.orig_freq != sample_rate:
            self.width = -1
            self.kernel = UninitializedParameter(device=None, dtype=None)

    def initialize_parameters(self, input):
        if self.has_uninitialized_params():
            if self.orig_freq != self.sample_rate:
                with torch.no_grad():
                    kernel, self.width = _get_sinc_resample_kernel(
                        self.orig_freq,
                        self.sample_rate,
                        self.gcd,
                        dtype=input.dtype,
                        device=input.device,
                    )
                    self.kernel.materialize(kernel.shape)
                    self.kernel.copy_(kernel)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

        Returns:
            Tensor: The pitch-shifted audio of shape `(..., time)`.
        """
        shape = waveform.size()

        waveform_stretch = _stretch_waveform(
            waveform,
            self.n_steps,
            self.bins_per_octave,
            self.n_fft,
            self.win_length,
            self.hop_length,
            self.window,
        )

        if self.orig_freq != self.sample_rate:
            waveform_shift = _apply_sinc_resample_kernel(
                waveform_stretch,
                self.orig_freq,
                self.sample_rate,
                self.gcd,
                self.kernel,
                self.width,
            )
        else:
            waveform_shift = waveform_stretch

        return _fix_waveform_shape(
            waveform_shift,
            shape,
        )


class RNNTLoss(torch.nn.Module):
    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
    [:footcite:`graves2012sequence`].

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    The RNN Transducer loss extends the CTC loss by defining a distribution over output
    sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    Args:
        blank (int, optional): blank label (Default: ``-1``)
        clamp (float, optional): clamp for gradients (Default: ``-1``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)

    Example
        >>> # Hypothetical values
        >>> logits = torch.tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],
        >>>                          [0.1, 0.1, 0.6, 0.1, 0.1],
        >>>                          [0.1, 0.1, 0.2, 0.8, 0.1]],
        >>>                         [[0.1, 0.6, 0.1, 0.1, 0.1],
        >>>                          [0.1, 0.1, 0.2, 0.1, 0.1],
        >>>                          [0.7, 0.1, 0.2, 0.1, 0.1]]]],
        >>>                       dtype=torch.float32,
        >>>                       requires_grad=True)
        >>> targets = torch.tensor([[1, 2]], dtype=torch.int)
        >>> logit_lengths = torch.tensor([2], dtype=torch.int)
        >>> target_lengths = torch.tensor([2], dtype=torch.int)
        >>> transform = transforms.RNNTLoss(blank=0)
        >>> loss = transform(logits, targets, logit_lengths, target_lengths)
        >>> loss.backward()
    """

    def __init__(
        self,
        blank: int = -1,
        clamp: float = -1.0,
        reduction: str = "mean",
    ):
        super().__init__()
        self.blank = blank
        self.clamp = clamp
        self.reduction = reduction

    def forward(
        self,
        logits: Tensor,
        targets: Tensor,
        logit_lengths: Tensor,
        target_lengths: Tensor,
    ):
        """
        Args:
            logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)`
                containing output from joiner
            targets (Tensor): Tensor of dimension `(batch, max target length)` containing zero-padded targets
            logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder
            target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence
        Returns:
            Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
            otherwise scalar.
        """
        return F.rnnt_loss(logits, targets, logit_lengths, target_lengths, self.blank, self.clamp, self.reduction)