from __future__ import absolute_import, division, print_function, unicode_literals
from warnings import warn
import math
import torch
from typing import Optional
from . import functional as F
from .compliance import kaldi


__all__ = [
    'Spectrogram',
    'AmplitudeToDB',
    'MelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
    'ComputeDeltas',
]


class Spectrogram(torch.jit.ScriptModule):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        pad (int): Two-sided padding of signal. (Default: ``0``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (int): Exponent for the magnitude spectrogram
            (must be > 0), e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
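
    Example
        A minimal usage sketch (``'test.wav'`` is a placeholder path):

        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram(n_fft=400)(waveform)  # (channel, n_fft // 2 + 1, time)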
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2, normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.window = torch.jit.Attribute(window, torch.Tensor)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: Dimension (channel, freq, time), where channel
            is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frames).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


class AmplitudeToDB(torch.jit.ScriptModule):
    r"""Turns a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is 80. (Default: ``None``)
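
    Example
        A minimal usage sketch (``specgram`` is assumed to be a power
        spectrogram, e.g. the output of :class:`Spectrogram`):

        >>> specgram = transforms.Spectrogram()(waveform)  # (channel, freq, time)
        >>> specgram_db = transforms.AmplitudeToDB(stype='power', top_db=80.)(specgram)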
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(AmplitudeToDB, self).__init__()
        self.stype = torch.jit.Attribute(stype, str)
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = torch.jit.Attribute(top_db, Optional[float])
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    @torch.jit.script_method
    def forward(self, x):
        r"""Numerically stable implementation from Librosa
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (torch.Tensor): Input tensor before being converted to decibel scale

        Returns:
            torch.Tensor: Output tensor in decibel scale
        """
        return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MelScale(torch.jit.ScriptModule):
    r"""This turns a normal STFT into a mel frequency STFT, using a conversion
    matrix.  This uses triangular filter banks.

    The user can control which device the filter bank (``fb``) is on (e.g., ``fb.to(spec_f.device)``).

    Args:
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`.
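
    Example
        A minimal usage sketch (``specgram`` is assumed to come from
        :class:`Spectrogram`):

        >>> specgram = transforms.Spectrogram()(waveform)  # (channel, freq, time)
        >>> mel_specgram = transforms.MelScale(n_mels=128, sample_rate=16000)(specgram)  # (channel, n_mels, time)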
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)
        self.f_min = f_min
        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels)
        self.fb = torch.jit.Attribute(fb, torch.Tensor)

    @torch.jit.script_method
    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
        """
        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
            # Attributes cannot be reassigned outside __init__, so resize and copy in place as a workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
        return mel_specgram


class MelSpectrogram(torch.jit.ScriptModule):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        win_length (int, optional): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``None``)
        pad (int): Two-sided padding of signal. (Default: ``0``)
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = torch.jit.Attribute(f_max, Optional[float])
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MFCC(torch.jit.ScriptModule):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
            ``False``)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
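
    Example
        A minimal usage sketch (``'test.wav'`` is a placeholder path):

        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mfcc = transforms.MFCC(sample_rate=sample_rate)(waveform)  # (channel, n_mfcc, time)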
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = torch.jit.Attribute(norm, Optional[str])
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
        self.log_mels = log_mels

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time)
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc


class MuLawEncoding(torch.jit.ScriptModule):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int): Number of channels. (Default: ``256``)
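
    Example
        A minimal usage sketch (``waveform`` is assumed to be scaled to [-1, 1]):

        >>> x_mu = transforms.MuLawEncoding(quantization_channels=256)(waveform)  # values in [0, 255]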
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded

        Returns:
            x_mu (torch.Tensor): An encoded signal
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.jit.ScriptModule):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels. (Default: ``256``)
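
    Example
        A minimal usage sketch (``x_mu`` is assumed to come from :class:`MuLawEncoding`):

        >>> x_mu = transforms.MuLawEncoding()(waveform)
        >>> waveform_reconstructed = transforms.MuLawDecoding()(x_mu)  # approximately recovers waveform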
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

        Returns:
            torch.Tensor: The decoded signal
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resamples a signal from one frequency to another. A resampling method can
    be given.

    Args:
        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
        new_freq (float): The desired frequency. (Default: ``16000``)
        resampling_method (str): The resampling method. (Default: ``'sinc_interpolation'``)
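
    Example
        A minimal usage sketch (``'test.wav'`` is a placeholder path):

        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> downsampled = transforms.Resample(orig_freq=sample_rate, new_freq=8000)(waveform)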
    """
    def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): The input signal of dimension (channel, time)

        Returns:
            torch.Tensor: Output signal of dimension (channel, time)
        """
        if self.resampling_method == 'sinc_interpolation':
            return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))


class ComputeDeltas(torch.jit.ScriptModule):
    r"""Compute delta coefficients of a tensor, usually a spectrogram.

    See `torchaudio.functional.compute_deltas` for more details.

    Args:
        win_length (int): The window length used for computing delta. (Default: ``5``)
        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
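
    Example
        A minimal usage sketch (``mfcc`` is assumed to come from :class:`MFCC`):

        >>> mfcc = transforms.MFCC(sample_rate=16000)(waveform)  # (channel, n_mfcc, time)
        >>> deltas = transforms.ComputeDeltas(win_length=5)(mfcc)  # (channel, n_mfcc, time)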
    """
    __constants__ = ['win_length']

    def __init__(self, win_length=5, mode="replicate"):
        super(ComputeDeltas, self).__init__()
        self.win_length = win_length
        self.mode = torch.jit.Attribute(mode, str)

    @torch.jit.script_method
    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)

        Returns:
            deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
        """
        return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)