transforms.py 16.8 KB
Newer Older
1
from __future__ import absolute_import, division, print_function, unicode_literals
2
from warnings import warn
3
import math
David Pollack's avatar
David Pollack committed
4
import torch
5
from typing import Optional
Jason Lian's avatar
pre  
Jason Lian committed
6
from . import functional as F
jamarshon's avatar
jamarshon committed
7
from .compliance import kaldi
8

Jason Lian's avatar
Jason Lian committed
9

10
11
# Public API of this module. Keep this list in sync with the classes defined
# below; ComputeDeltas was defined in this file but previously not exported.
__all__ = [
    'Spectrogram',
    'AmplitudeToDB',
    'MelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComplexNorm',
    'ComputeDeltas',
]


23
class Spectrogram(torch.jit.ScriptModule):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
        win_length (int): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows. (
            Default: ``win_length // 2``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (int): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
    """
    # Listed so TorchScript treats these attributes as compile-time constants.
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2, normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        # The window tensor is built once here (optionally with extra kwargs)
        # and registered as a jit Attribute so the scripted forward can use it.
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.window = torch.jit.Attribute(window, torch.Tensor)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: Dimension (channel, freq, time), where channel
            is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
        """
        # All of the STFT work is delegated to the functional implementation.
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)
69
70


71
72
class AmplitudeToDB(torch.jit.ScriptModule):
    r"""Turns a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs.
    a full clip.

    Args:
        stype (str): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(AmplitudeToDB, self).__init__()
        self.stype = torch.jit.Attribute(stype, str)
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be positive value')
        self.top_db = torch.jit.Attribute(top_db, Optional[float])
        # Power spectra are scaled by 10*log10, magnitude spectra by 20*log10.
        self.multiplier = 10.0 if stype == 'power' else 20.0
        # amin clamps tiny inputs before the log to avoid -inf.
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    @torch.jit.script_method
    def forward(self, x):
        r"""Numerically stable implementation from Librosa
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (torch.Tensor): Input tensor before being converted to decibel scale

        Returns:
            torch.Tensor: Output tensor in decibel scale
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
109
110


111
class MelScale(torch.jit.ScriptModule):
    r"""This turns a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`.
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        # Fixed: the message previously printed '<' although the check allows
        # equality, which was misleading when f_min == f_max failed elsewhere.
        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)
        self.f_min = f_min
        # When n_stft is unknown, defer building the filter bank until the
        # first forward call; an empty tensor marks "not built yet".
        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
        self.fb = torch.jit.Attribute(fb, torch.Tensor)

    @torch.jit.script_method
    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
        """
        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
        return mel_specgram
157

158

159
160
161
class MelSpectrogram(torch.jit.ScriptModule):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        win_length (int): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows. (
            Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``None``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = torch.jit.Attribute(f_max, Optional[float])
        self.f_min = f_min
        # Power spectrogram (power=2) feeding a mel-scale projection; the two
        # submodules do all of the work in forward.
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram
217
218


219
class MFCC(torch.jit.ScriptModule):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs.
    a full clip.

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
            ``False``)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)

    Raises:
        ValueError: If ``dct_type`` is not a supported DCT type, or if
            ``n_mfcc`` exceeds the number of mel bins.
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            # Fixed: .format() was called on a string with no placeholder, so
            # the offending dct_type was silently dropped from the message.
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = torch.jit.Attribute(norm, Optional[str])
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        # The DCT basis is precomputed once and applied via matmul in forward.
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
        self.log_mels = log_mels

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time)
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            # Small offset avoids log(0) on silent frames.
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)
        # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).tranpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc
284
285


286
class MuLawEncoding(torch.jit.ScriptModule):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        # Only the channel count is stored; the companding math lives in
        # functional.mu_law_encoding.
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded

        Returns:
            x_mu (torch.Tensor): An encoded signal
        """
        return F.mu_law_encoding(x, self.quantization_channels)
312

Soumith Chintala's avatar
Soumith Chintala committed
313

314
class MuLawDecoding(torch.jit.ScriptModule):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        # Only the channel count is stored; the expansion math lives in
        # functional.mu_law_decoding.
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

        Returns:
            torch.Tensor: The signal decoded
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)
jamarshon's avatar
jamarshon committed
340
341
342


class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another, using a selectable
    resampling method.

    Args:
        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
        new_freq (float): The desired frequency. (Default: ``16000``)
        resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
    """
    def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        # Plain attributes only; the actual resampling is delegated to the
        # kaldi-compliance helpers at forward time.
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform):
        r"""Resample the given waveform from ``orig_freq`` to ``new_freq``.

        Args:
            waveform (torch.Tensor): The input signal of dimension (channel, time)

        Returns:
            torch.Tensor: Output signal of dimension (channel, time)

        Raises:
            ValueError: If ``resampling_method`` is not a recognized method.
        """
        # Guard clause: reject unknown methods up front, then dispatch.
        if self.resampling_method != 'sinc_interpolation':
            raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
        return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
Vincent QB's avatar
Vincent QB committed
369
370


371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class ComplexNorm(torch.jit.ScriptModule):
    r"""Compute the norm of complex tensor input

    Args:
        power (float): Power of the norm. Defaults to `1.0`.
    """
    __constants__ = ['power']

    def __init__(self, power=1.0):
        super(ComplexNorm, self).__init__()
        # power is a TorchScript constant (see __constants__ above).
        self.power = power

    @torch.jit.script_method
    def forward(self, complex_tensor):
        r"""
        Args:
            complex_tensor (Tensor): Tensor shape of `(*, complex=2)`

        Returns:
            Tensor: norm of the input tensor, shape of `(*, )`
        """
        return F.complex_norm(complex_tensor, self.power)


Vincent QB's avatar
Vincent QB committed
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
class ComputeDeltas(torch.jit.ScriptModule):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Padding mode forwarded to the underlying functional.
            (Default: ``"replicate"``)
    """
    __constants__ = ['win_length']

    def __init__(self, win_length=5, mode="replicate"):
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        # mode is registered as a typed string attribute (not a constant) so
        # the scripted forward can read it.
        self.mode = torch.jit.Attribute(mode, str)

    @torch.jit.script_method
    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time)

        Returns:
            deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time)
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)