from __future__ import division, print_function
from warnings import warn
import math
import torch
from typing import Optional
from . import functional as F
from .compliance import kaldi


# TODO remove this class
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.Scale(),
        >>>     transforms.PadTrim(max_len=16000),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio):
        for t in self.transforms:
            audio = t(audio)
        return audio

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class Scale(torch.jit.ScriptModule):
    """Scale audio tensor from integer PCM values (represented as a FloatTensor)
    to floating point numbers between -1.0 and 1.0 by dividing by `factor`.  Note
    that the number of bits is called the "bit depth" or "precision", not to be
    confused with "bit rate".

    Args:
        factor (int): maximum value of input tensor. default: 2**31 (32-bit depth)

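    Example:
        >>> # a minimal sketch: two fabricated samples, assuming the default
        >>> # factor of 2**31
        >>> t = torch.tensor([[2 ** 30], [-2 ** 30]], dtype=torch.float)
        >>> transforms.Scale()(t)  # -> values 0.5 and -0.5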
    """
    __constants__ = ['factor']

    def __init__(self, factor=2**31):
        super(Scale, self).__init__()
        self.factor = factor

    @torch.jit.script_method
    def forward(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)

        """
        return F.scale(tensor, self.factor)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class PadTrim(torch.jit.ScriptModule):
    """Pad/Trim a 2d-Tensor (Signal or Labels)

    Args:
        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
        max_len (int): Length to which the tensor will be padded
        channels_first (bool): Pad for channels first tensors.  Default: `True`

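    Example:
        >>> # a minimal sketch: pad a fabricated (c x n) signal out to 16000 samples
        >>> tensor = torch.randn(1, 12000)
        >>> transforms.PadTrim(max_len=16000)(tensor)  # -> size (1, 16000)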
    """
    __constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim']

    def __init__(self, max_len, fill_value=0., channels_first=True):
        super(PadTrim, self).__init__()
        self.max_len = max_len
        self.fill_value = fill_value
        self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)

    @torch.jit.script_method
    def forward(self, tensor):
        """

        Returns:
            Tensor: (c x n) or (n x c)

        """
        return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value)

    def __repr__(self):
        return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)


class DownmixMono(torch.jit.ScriptModule):
    """Downmix any stereo signals to mono.  Consider using a `SoxEffectsChain` with
       the `channels` effect instead of this transformation.

    Inputs:
        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
        channels_first (bool): Downmix across channels dimension.  Default: `True`

    Returns:
        tensor (Tensor): Mono signal with a singleton channel dimension, i.e.
            (1 x n) or (n x 1)

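    Example:
        >>> # a minimal sketch: a fabricated stereo signal, channels first
        >>> tensor = torch.randn(2, 16000)  # (c x n)
        >>> transforms.DownmixMono(channels_first=True)(tensor)  # -> (1, 16000)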
    """
    __constants__ = ['ch_dim']

    def __init__(self, channels_first=True):
        super(DownmixMono, self).__init__()
        self.ch_dim = int(not channels_first)

    @torch.jit.script_method
    def forward(self, tensor):
        return F.downmix_mono(tensor, self.ch_dim)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class LC2CL(torch.jit.ScriptModule):
    """Permute a 2d tensor from samples-first (n x c) to channels-first (c x n)
    """

    def __init__(self):
        super(LC2CL, self).__init__()

    @torch.jit.script_method
    def forward(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio signal with shape (LxC)

        Returns:
            tensor (Tensor): Tensor of audio signal with shape (CxL)
        """
        return F.LC2CL(tensor)

    def __repr__(self):
        return self.__class__.__name__ + '()'


def SPECTROGRAM(*args, **kwargs):
    warn("SPECTROGRAM has been renamed to Spectrogram", DeprecationWarning)
    return Spectrogram(*args, **kwargs)


class Spectrogram(torch.jit.ScriptModule):
    """Create a spectrogram from a raw audio signal

    Args:
        n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins
        ws (int): window size. default: n_fft
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        pad (int): two sided padding of signal
        window (torch windowing function): default: torch.hann_window
        power (int > 0): Exponent for the magnitude spectrogram,
            e.g., 1 for energy, 2 for power, etc.
        normalize (bool): whether to normalize by magnitude after stft
        wkwargs (dict, optional): arguments for window function
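
    Example:
        >>> # a minimal sketch on fabricated audio, assuming the defaults
        >>> # (n_fft=400, hop=200)
        >>> sig = torch.randn(1, 16000)  # (c, n)
        >>> transforms.Spectrogram()(sig)  # -> (1, l, 201), l = number of hops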
    """
    __constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize']

    def __init__(self, n_fft=400, ws=None, hop=None,
                 pad=0, window=torch.hann_window,
                 power=2, normalize=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of fft bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.ws = ws if ws is not None else n_fft
        self.hop = hop if hop is not None else self.ws // 2
        window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
        self.window = torch.jit.Attribute(window, torch.Tensor)
        self.pad = pad
        self.power = power
        self.normalize = normalize

    @torch.jit.script_method
    def forward(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and n_fft is the
                number of Fourier bins, which equals n_fft // 2 + 1.

        """
        return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop,
                             self.ws, self.power, self.normalize)


def F2M(*args, **kwargs):
    warn("F2M has been renamed to MelScale", DeprecationWarning)
    return MelScale(*args, **kwargs)


class MelScale(torch.jit.ScriptModule):
    """This turns a normal STFT into a mel frequency STFT, using a conversion
       matrix.  This uses triangular filter banks.

       The user can control which device the filter bank (`fb`) resides on
       (e.g. `fb.to(spec_f.device)`).

    Args:
        n_mels (int): number of mel bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: `sr` // 2
        f_min (float): minimum frequency. default: 0
        n_stft (int, optional): number of filter banks from stft. Calculated from first input
            if `None` is given.  See `n_fft` in `Spectrogram`.
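
    Example:
        >>> # a minimal sketch on a fabricated power spectrogram with
        >>> # n_fft // 2 + 1 = 201 frequency bins (n_fft=400)
        >>> spec_f = torch.randn(1, 81, 201).abs()
        >>> transforms.MelScale(n_mels=128, sr=16000)(spec_f)  # -> (1, 81, 128)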
    """
    __constants__ = ['n_mels', 'sr', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sr = sr
        self.f_max = f_max if f_max is not None else float(sr // 2)
        self.f_min = f_min
        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels)
        self.fb = torch.jit.Attribute(fb, torch.Tensor)

    @torch.jit.script_method
    def forward(self, spec_f):
        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)
        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m


class SpectrogramToDB(torch.jit.ScriptModule):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude").  The
            power is the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80.
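
    Example:
        >>> # a minimal sketch: a fabricated power spectrogram, clamped to an
        >>> # 80 dB dynamic range
        >>> spec = transforms.Spectrogram()(torch.randn(1, 16000))
        >>> transforms.SpectrogramToDB("power", top_db=80.)(spec)  # same size, in dB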
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype="power", top_db=None):
        super(SpectrogramToDB, self).__init__()
        self.stype = torch.jit.Attribute(stype, str)
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a non-negative value')
        self.top_db = torch.jit.Attribute(top_db, Optional[float])
        self.multiplier = 10. if stype == "power" else 20.
        self.amin = 1e-10
        self.ref_value = 1.
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    @torch.jit.script_method
    def forward(self, spec):
        # numerically stable implementation from librosa
        # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
        return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MFCC(torch.jit.ScriptModule):
    """Create the Mel-frequency cepstrum coefficients from an audio signal

        By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
        This is not the textbook implementation, but is implemented here to
        give consistency with librosa.

        This output depends on the maximum value in the input spectrogram, and so
        may return different values for an audio clip split into snippets vs. a
        full clip.

        Args:
            sr (int): sample rate of audio signal
            n_mfcc (int): number of mfc coefficients to retain
            dct_type (int): type of DCT (discrete cosine transform) to use
            norm (string, optional): norm to use
            log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
            melkwargs (dict, optional): arguments for MelSpectrogram
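
        Example:
            >>> # a minimal sketch on fabricated audio, using the defaults
            >>> sig = torch.randn(1, 16000)  # (c, n)
            >>> transforms.MFCC(sr=16000, n_mfcc=40)(sig)  # -> (1, l, 40)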
    """
299
300
    __constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

PCerles's avatar
PCerles committed
301
302
    def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
303
        super(MFCC, self).__init__()
PCerles's avatar
PCerles committed
304
305
306
307
308
309
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported'.format(dct_type))
        self.sr = sr
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
310
        self.norm = torch.jit.Attribute(norm, Optional[str])
PCerles's avatar
PCerles committed
311
312
313
314
315
316
317
318
319
320
        self.top_db = 80.
        self.s2db = SpectrogramToDB("power", self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sr=self.sr)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
321
322
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
PCerles's avatar
PCerles committed
323
324
        self.log_mels = log_mels

    @torch.jit.script_method
    def forward(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            mfcc (Tensor): channels x hops x n_mfcc (c, l, n_mfcc), where channels
                is unchanged, hops is the number of hops, and n_mfcc is the
                number of cepstral coefficients retained.
        """
        mel_spect = self.MelSpectrogram(sig)
        if self.log_mels:
            log_offset = 1e-6
            mel_spect = torch.log(mel_spect + log_offset)
        else:
            mel_spect = self.s2db(mel_spect)
        mfcc = torch.matmul(mel_spect, self.dct_mat)
        return mfcc


class MelSpectrogram(torch.jit.ScriptModule):
    """Create MEL Spectrograms from a raw audio signal using the stft
       function in PyTorch.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sr (int): sample rate of audio signal
        n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins. default: 400
        ws (int): window size. default: n_fft
        hop (int, optional): length of hop between STFT windows. default: `ws` // 2
        f_max (float, optional): maximum frequency. default: `sr` // 2
        f_min (float): minimum frequency. default: 0
        pad (int): two sided padding of signal
        n_mels (int): number of MEL bins
        window (torch windowing function): default: `torch.hann_window`
        wkwargs (dict, optional): arguments for window function

    Example:
        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
        >>> spec_mel = transforms.MelSpectrogram(sr)(sig)  # (c, l, m)
    """
    __constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min']

    def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sr = sr
        self.n_fft = n_fft
        self.ws = ws if ws is not None else n_fft
        self.hop = hop if hop is not None else self.ws // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = torch.jit.Attribute(f_max, Optional[float])
        self.f_min = f_min
        self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop,
                                pad=self.pad, window=window, power=2,
                                normalize=False, wkwargs=wkwargs)
        self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)

    @torch.jit.script_method
    def forward(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels
                is unchanged, hops is the number of hops, and n_mels is the
                number of mel bins.

        """
        spec = self.spec(sig)
        spec_mel = self.fm(spec)
        return spec_mel


def MEL(*args, **kwargs):
    raise DeprecationWarning("MEL has been removed from the library. Please use MelSpectrogram or librosa.")


class BLC2CBL(torch.jit.ScriptModule):
    """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
       Bands x Sample length
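
    Example:
        >>> # a minimal sketch: a fabricated (bands, length, channels) input
        >>> transforms.BLC2CBL()(torch.randn(128, 100, 1))  # -> size (1, 128, 100)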
    """

    def __init__(self):
        super(BLC2CBL, self).__init__()

    @torch.jit.script_method
    def forward(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)

        Returns:
            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)

        """
        return F.BLC2CBL(tensor)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class MuLawEncoding(torch.jit.ScriptModule):
    """Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int): Number of channels. default: 256

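    Example:
        >>> # a minimal sketch: a fabricated signal already scaled to [-1, 1]
        >>> x = torch.rand(1, 16000) * 2 - 1
        >>> transforms.MuLawEncoding(quantization_channels=256)(x)  # values in [0, 255]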
    """
446
    __constants__ = ['qc']
David Pollack's avatar
David Pollack committed
447
448

    def __init__(self, quantization_channels=256):
449
        super(MuLawEncoding, self).__init__()
David Pollack's avatar
David Pollack committed
450
451
        self.qc = quantization_channels

452
453
    @torch.jit.script_method
    def forward(self, x):
David Pollack's avatar
David Pollack committed
454
455
456
        """

        Args:
457
            x (FloatTensor/LongTensor)
David Pollack's avatar
David Pollack committed
458
459

        Returns:
460
            x_mu (LongTensor)
David Pollack's avatar
David Pollack committed
461
462

        """
Jason Lian's avatar
pre  
Jason Lian committed
463
        return F.mu_law_encoding(x, self.qc)
David Pollack's avatar
David Pollack committed
464

465
466
467
    def __repr__(self):
        return self.__class__.__name__ + '()'


class MuLawExpanding(torch.jit.ScriptModule):
    """Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels. default: 256

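    Example:
        >>> # a minimal sketch: round-trip through MuLawEncoding and back
        >>> x = torch.rand(1, 16000) * 2 - 1
        >>> x_mu = transforms.MuLawEncoding()(x)
        >>> transforms.MuLawExpanding()(x_mu)  # approximately x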
    """
480
    __constants__ = ['qc']
David Pollack's avatar
David Pollack committed
481
482

    def __init__(self, quantization_channels=256):
483
        super(MuLawExpanding, self).__init__()
David Pollack's avatar
David Pollack committed
484
485
        self.qc = quantization_channels

486
487
    @torch.jit.script_method
    def forward(self, x_mu):
David Pollack's avatar
David Pollack committed
488
489
490
        """

        Args:
Jason Lian's avatar
Jason Lian committed
491
            x_mu (Tensor)
David Pollack's avatar
David Pollack committed
492
493

        Returns:
Jason Lian's avatar
Jason Lian committed
494
            x (Tensor)
David Pollack's avatar
David Pollack committed
495
496

        """
Jason Lian's avatar
pre  
Jason Lian committed
497
        return F.mu_law_expanding(x_mu, self.qc)
498
499
500

    def __repr__(self):
        return self.__class__.__name__ + '()'


class Resample(torch.nn.Module):
    """Resamples a signal from one frequency to another. A resampling method can
    be given.

    Args:
        orig_freq (float): the original frequency of the signal
        new_freq (float): the desired frequency
        resampling_method (str): the resampling method. (Default: 'sinc_interpolation',
            which uses Kaldi-style sinc interpolation)
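
    Example:
        >>> # a minimal sketch: downsample fabricated audio from 16 kHz to 8 kHz
        >>> sig = torch.randn(1, 16000)  # (c, n)
        >>> transforms.Resample(16000, 8000)(sig)  # -> (c, ~8000)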
    """
    def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, sig):
        """
        Args:
            sig (Tensor): the input signal of size (c, n)

        Returns:
            Tensor: output signal of size (c, m)
        """
        if self.resampling_method == 'sinc_interpolation':
            return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq)

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))