# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
from warnings import warn
import math

import torch
from typing import Optional

from . import functional as F
from .compliance import kaldi


__all__ = [
    'Spectrogram',
    'GriffinLim',
    'AmplitudeToDB',
    'MelScale',
    'InverseMelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'ComputeDeltas',
    'TimeStretch',
    'FrequencyMasking',
    'TimeMasking',
]


class Spectrogram(torch.nn.Module):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
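
    Example
        >>> # A minimal usage sketch; assumes 'test.wav' exists and ``transforms`` is ``torchaudio.transforms``.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, time)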
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2., normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time)

        Returns:
            torch.Tensor: Dimension (..., freq, time), where freq is
            ``n_fft // 2 + 1`` with ``n_fft`` the number of Fourier bins,
            and time is the number of window hops (n_frame).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


class GriffinLim(torch.nn.Module):
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        n_iter (int, optional): Number of iterations for phase recovery process. (Default: ``32``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
        length (int, optional): Array length of the expected output. (Default: ``None``)
        rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
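
    Example
        >>> # A sketch; assumes ``specgram`` is a power spectrogram from transforms.Spectrogram(power=2.).
        >>> # Phase recovery is iterative and approximate, so the result is an estimate of the waveform.
        >>> restored_waveform = transforms.GriffinLim(n_fft=400, power=2.)(specgram)  # (channel, time)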
    """
    __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                     'length', 'momentum', 'rand_init']

    def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None,
                 window_fn=torch.hann_window, power=2., normalized=False, wkwargs=None,
                 momentum=0.99, length=None, rand_init=True):
        super(GriffinLim, self).__init__()

        assert momentum < 1, 'momentum=%s >= 1 can be unstable' % momentum
        assert momentum >= 0, 'momentum=%s < 0' % momentum

        self.n_fft = n_fft
        self.n_iter = n_iter
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.register_buffer('window', window)
        self.normalized = normalized
        self.length = length
        self.power = power
        self.momentum = momentum / (1 + momentum)
        self.rand_init = rand_init

    def forward(self, S):
        r"""
        Args:
            S (torch.Tensor): A magnitude-only spectrogram of dimension (..., freq, time)

        Returns:
            torch.Tensor: Estimated waveform of dimension (..., time)
        """
        return F.griffinlim(S, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                            self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)


class AmplitudeToDB(torch.nn.Module):
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(AmplitudeToDB, self).__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = top_db
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x):
        r"""Numerically stable implementation from Librosa.
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (torch.Tensor): Input tensor before being converted to decibel scale

        Returns:
            torch.Tensor: Output tensor in decibel scale
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.nn.Module):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control which device the filter bank (`fb`) is on (e.g. ``fb.to(spec_f.device)``).

    Args:
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`.
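
    Example
        >>> # A sketch; assumes ``specgram`` is a power spectrogram of shape (channel, freq, time).
        >>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=16000)(specgram)  # (channel, n_mels, time)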
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.view(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])

        return mel_specgram


class InverseMelScale(torch.nn.Module):
    r"""Solve for a normal STFT from a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    It minimizes the euclidean norm between the input mel-spectrogram and the product of
    the estimated spectrogram and the filter banks using SGD.

    Args:
        n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        max_iter (int): Maximum number of optimization iterations. (Default: ``100000``)
        tolerance_loss (float): Value of loss to stop optimization at. (Default: ``1e-5``)
        tolerance_change (float): Difference in losses to stop optimization at. (Default: ``1e-8``)
        sgdargs (dict, optional): Arguments for the SGD optimizer. (Default: ``None``)
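
    Example
        >>> # A sketch; the estimate is found by SGD, so the result is approximate and varies between runs.
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate=16000)(waveform)  # (channel, 128, time)
        >>> specgram = transforms.InverseMelScale(n_stft=201, n_mels=128)(mel_specgram)  # (channel, 201, time)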
    """
    __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
                     'tolerance_change', 'sgdargs']

    def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000,
                 tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None):
        super(InverseMelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max or float(sample_rate // 2)
        self.f_min = f_min
        self.max_iter = max_iter
        self.tolerance_loss = tolerance_loss
        self.tolerance_change = tolerance_change
        self.sgdargs = sgdargs or {'lr': 0.1, 'momentum': 0.9}

        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)

        fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.register_buffer('fb', fb)

    def forward(self, melspec):
        r"""
        Args:
            melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)

        Returns:
            torch.Tensor: Linear scale spectrogram of size (..., freq, time)
        """
        # pack batch
        shape = melspec.size()
        melspec = melspec.view(-1, shape[-2], shape[-1])

        n_mels, time = shape[-2], shape[-1]
        freq, _ = self.fb.size()  # (freq, n_mels)
        melspec = melspec.transpose(-1, -2)
        assert self.n_mels == n_mels

        specgram = torch.rand(melspec.size()[0], time, freq, requires_grad=True,
                              dtype=melspec.dtype, device=melspec.device)

        optim = torch.optim.SGD([specgram], **self.sgdargs)

        loss = float('inf')
        for _ in range(self.max_iter):
            optim.zero_grad()
            diff = melspec - specgram.matmul(self.fb)
            new_loss = diff.pow(2).sum(dim=-1).mean()
            # take sum over mel-frequency then average over other dimensions
            # so that loss threshold is applied per unit timeframe
            new_loss.backward()
            optim.step()
            specgram.data = specgram.data.clamp(min=0)

            new_loss = new_loss.item()
            if new_loss < self.tolerance_loss or abs(loss - new_loss) < self.tolerance_change:
                break
            loss = new_loss

        specgram.requires_grad_(False)
        specgram = specgram.clamp(min=0).transpose(-1, -2)

        # unpack batch
        specgram = specgram.view(shape[:-2] + (freq, time))
        return specgram


class MelSpectrogram(torch.nn.Module):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``None``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = f_max
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2.,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.nn.Module):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
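
    Example
        >>> # A sketch; assumes 'test.wav' exists.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mfcc = transforms.MFCC(sample_rate=sample_rate, n_mfcc=40)(waveform)  # (channel, n_mfcc, time)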
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = norm
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)
        self.log_mels = log_mels

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (..., time)

        Returns:
            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
        """
        # pack batch
        shape = waveform.size()
        waveform = waveform.view(-1, shape[-1])

        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)

        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])

        return mfcc


class MuLawEncoding(torch.nn.Module):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
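
    Example
        >>> # A sketch; assumes ``waveform`` is already scaled to [-1., 1.].
        >>> encoded = transforms.MuLawEncoding(quantization_channels=256)(waveform)  # integer values in [0, 255]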
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded

        Returns:
            x_mu (torch.Tensor): An encoded signal
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.nn.Module):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
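
    Example
        >>> # A sketch; inverts the MuLawEncoding example above.
        >>> decoded = transforms.MuLawDecoding(quantization_channels=256)(encoded)  # values in [-1., 1.]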
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

        Returns:
            torch.Tensor: The decoded signal
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can
    be given.

    Args:
        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
        new_freq (float): The desired frequency. (Default: ``16000``)
        resampling_method (str): The resampling method. (Default: ``'sinc_interpolation'``)
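
    Example
        >>> # A sketch; downsamples a 16 kHz waveform to 8 kHz.
        >>> resampled = transforms.Resample(orig_freq=16000, new_freq=8000)(waveform)  # (channel, new_time)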
    """
    def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): The input signal of dimension (..., time)

        Returns:
            torch.Tensor: Output signal of dimension (..., time)
        """
        if self.resampling_method == 'sinc_interpolation':

            # pack batch
            shape = waveform.size()
            waveform = waveform.view(-1, shape[-1])

            waveform = kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

            # unpack batch
            waveform = waveform.view(shape[:-1] + waveform.shape[-1:])

            return waveform

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))


class ComplexNorm(torch.nn.Module):
    r"""Compute the norm of complex tensor input
    Args:
        power (float): Power of the norm. Defaults to `1.0`.
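
    Example
        >>> # A sketch; assumes ``complex_specgram`` comes from torch.stft, with a trailing dimension of size 2.
        >>> magnitude = transforms.ComplexNorm(power=1.)(complex_specgram)  # (..., freq, time)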
    """
    __constants__ = ['power']

    def __init__(self, power=1.0):
        super(ComplexNorm, self).__init__()
        self.power = power

    def forward(self, complex_tensor):
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        Returns:
            Tensor: norm of the input tensor, shape of `(..., )`
        """
        return F.complex_norm(complex_tensor, self.power)


class ComputeDeltas(torch.nn.Module):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
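
    Example
        >>> # A sketch; computes deltas along the time axis of a spectrogram.
        >>> specgram = transforms.Spectrogram()(waveform)
        >>> deltas = transforms.ComputeDeltas(win_length=5)(specgram)  # same shape as specgram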
    """
    __constants__ = ['win_length']

    def __init__(self, win_length=5, mode="replicate"):
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = mode

    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)

        Returns:
            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)


class TimeStretch(torch.nn.Module):
    r"""Stretch stft in time without modifying pitch for a given rate.

    Args:
        hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``)
        n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
        fixed_rate (float): rate to speed up or slow down by.
            If None is provided, rate must be passed to the forward method. (Default: ``None``)
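
    Example
        >>> # A sketch; a rate of 1.2 speeds the signal up by 20% without changing pitch.
        >>> # ``complex_specgram`` has shape (..., freq, time, complex=2), e.g. from torch.stft.
        >>> stretched = transforms.TimeStretch(n_freq=201, fixed_rate=1.2)(complex_specgram)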
    """
    __constants__ = ['fixed_rate']

    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        super(TimeStretch, self).__init__()

        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        hop_length = hop_length if hop_length is not None else n_fft // 2
        self.phase_advance = torch.linspace(0, math.pi * hop_length, n_freq)[..., None]

    def forward(self, complex_specgrams, overriding_rate=None):
        # type: (Tensor, Optional[float]) -> Tensor
        r"""
        Args:
            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
            overriding_rate (float or None): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``

        Returns:
            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
        """
        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

        if overriding_rate is None:
            rate = self.fixed_rate
            if rate is None:
                raise ValueError("If no fixed_rate is specified"
                                 ", must pass a valid rate to the forward method.")
        else:
            rate = overriding_rate

        if rate == 1.0:
            return complex_specgrams

        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


class _AxisMasking(torch.nn.Module):
    r"""Apply masking to a spectrogram.

    Args:
        mask_param (int): Maximum possible length of the mask
        axis (int): What dimension the mask is applied on
        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
    """
    __constants__ = ['mask_param', 'axis', 'iid_masks']

    def __init__(self, mask_param, axis, iid_masks):
        super(_AxisMasking, self).__init__()
        self.mask_param = mask_param
        self.axis = axis
        self.iid_masks = iid_masks

    def forward(self, specgram, mask_value=0.):
        # type: (Tensor, float) -> Tensor
        r"""
        Args:
            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
            mask_value (float): Value to assign to the masked columns. (Default: ``0.``)

        Returns:
            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
        """

        # if the iid_masks flag is set and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
            return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
        else:
            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)


class FrequencyMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the frequency domain.

    Args:
        freq_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, freq_mask_param).
        iid_masks (bool): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
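
    Example
        >>> # A sketch; masks up to 30 consecutive frequency bins at a random position.
        >>> masked = transforms.FrequencyMasking(freq_mask_param=30)(specgram)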
    """

    def __init__(self, freq_mask_param, iid_masks=False):
        super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)


class TimeMasking(_AxisMasking):
    r"""Apply masking to a spectrogram in the time domain.

    Args:
        time_mask_param (int): maximum possible length of the mask.
            Indices uniformly sampled from [0, time_mask_param).
        iid_masks (bool): whether to apply different masks to each
            example/channel in the batch. (Default: ``False``)
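
    Example
        >>> # A sketch; masks up to 100 consecutive time steps at a random position.
        >>> masked = transforms.TimeMasking(time_mask_param=100)(specgram)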
    """

    def __init__(self, time_mask_param, iid_masks=False):
        super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)