from __future__ import division, print_function
from warnings import warn
import math
import torch
from typing import Optional
from . import functional as F
from .compliance import kaldi


class Spectrogram(torch.jit.ScriptModule):
    r"""Create a spectrogram from an audio signal

    Args:
        n_fft (int, optional): Size of FFT, creates `n_fft // 2 + 1` bins. (Default: 400)
        win_length (int): Window size. (Default: `n_fft`)
        hop_length (int, optional): Length of hop between STFT windows. (
            Default: `win_length // 2`)
        pad (int): Two-sided padding of signal. (Default: 0)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
        power (int): Exponent for the magnitude spectrogram (must be > 0),
            e.g. 1 for energy, 2 for power, etc. (Default: 2)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
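
    Example:
        >>> # A minimal usage sketch; 'test.wav' is a placeholder path.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> specgram = transforms.Spectrogram()(waveform)  # (c, n_fft // 2 + 1, t)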
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2, normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()
        self.n_fft = n_fft
        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
        self.window = torch.jit.Attribute(window, torch.Tensor)
        self.pad = pad
        self.power = power
        self.normalized = normalized

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of size (c, n)

        Returns:
            torch.Tensor: Channels x frequency x time (c, f, t), where channels
            is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
            Fourier bins, and time is the number of window hops (n_frames).
        """
        return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
                             self.win_length, self.power, self.normalized)


class MelScale(torch.jit.ScriptModule):
    r"""This turns a normal STFT into a mel frequency STFT, using a conversion
       matrix.  This uses triangular filter banks.

       User can control which device the filter bank (`fb`) resides on (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int): Number of mel filterbanks. (Default: 128)
        sample_rate (int): Sample rate of audio signal. (Default: 16000)
        f_min (float): Minimum frequency. (Default: 0.)
        f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if `None` is given.  See `n_fft` in `Spectrogram`.
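
    Example:
        >>> # A minimal sketch; `waveform` and `sample_rate` as returned by torchaudio.load.
        >>> specgram = transforms.Spectrogram()(waveform)  # (c, f, t)
        >>> mel_specgram = transforms.MelScale(sample_rate=sample_rate)(specgram)  # (c, n_mels, t)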
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        assert f_min <= self.f_max, 'Require f_min: %f <= f_max: %f' % (f_min, self.f_max)
        self.f_min = f_min
        fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels)
        self.fb = torch.jit.Attribute(fb, torch.Tensor)

    @torch.jit.script_method
    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): a spectrogram STFT of size (c, f, t)

        Returns:
            torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
        """
        if self.fb.numel() == 0:
            tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
            # Attributes cannot be reassigned outside __init__, so resize and copy in place
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
        return mel_specgram


class SpectrogramToDB(torch.jit.ScriptModule):
    r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input spectrogram ('power' or 'magnitude'). Power
            is the elementwise square of the magnitude. (Default: 'power')
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable
            number is 80. (Default: `None`)
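
    Example:
        >>> # A minimal sketch; `waveform` as returned by torchaudio.load.
        >>> specgram = transforms.Spectrogram()(waveform)  # power spectrogram
        >>> specgram_db = transforms.SpectrogramToDB(stype='power', top_db=80.)(specgram)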
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(SpectrogramToDB, self).__init__()
        self.stype = torch.jit.Attribute(stype, str)
        if top_db is not None and top_db < 0:
            raise ValueError('top_db must be a positive value')
        self.top_db = torch.jit.Attribute(top_db, Optional[float])
        self.multiplier = 10.0 if stype == 'power' else 20.0
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    @torch.jit.script_method
    def forward(self, specgram):
        r"""Numerically stable implementation from Librosa
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            specgram (torch.Tensor): STFT of size (c, f, t)

        Returns:
            torch.Tensor: STFT after changing scale of size (c, f, t)
        """
        return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db)


class MFCC(torch.jit.ScriptModule):
    r"""Create the Mel-frequency cepstral coefficients from an audio signal

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: 16000)
        n_mfcc (int): Number of mfc coefficients to retain. (Default: 40)
        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: 2)
        norm (str, optional): norm to use. (Default: 'ortho')
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default: `False`)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: `None`)
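
    Example:
        >>> # A minimal usage sketch; 'test.wav' is a placeholder path.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mfcc = transforms.MFCC(sample_rate=sample_rate)(waveform)  # (c, n_mfcc, t)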
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = torch.jit.Attribute(norm, Optional[str])
        self.top_db = 80.0
        self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db)

        if melkwargs is not None:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
        else:
            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
        self.log_mels = log_mels

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of size (c, n)

        Returns:
            torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t)
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.spectrogram_to_DB(mel_specgram)
        # (c, `n_mels`, t).transpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc


class MelSpectrogram(torch.jit.ScriptModule):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: 16000)
        win_length (int): Window size. (Default: `n_fft`)
        hop_length (int, optional): Length of hop between STFT windows. (
            Default: `win_length // 2`)
        n_fft (int, optional): Size of FFT, creates `n_fft // 2 + 1` bins. (Default: 400)
        f_min (float): Minimum frequency. (Default: 0.)
        f_max (float, optional): Maximum frequency. (Default: `None`)
        pad (int): Two-sided padding of signal. (Default: 0)
        n_mels (int): Number of mel filterbanks. (Default: 128)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)

    Example:
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (c, n_mels, t)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
                 pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = torch.jit.Attribute(f_max, Optional[float])
        self.f_min = f_min
        self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad, window_fn=window_fn, power=2,
                                       normalized=False, wkwargs=wkwargs)
        self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max)

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of size (c, n)

        Returns:
            torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
        """
        specgram = self.spectrogram(waveform)
        mel_specgram = self.mel_scale(specgram)
        return mel_specgram


class MuLawEncoding(torch.jit.ScriptModule):
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        quantization_channels (int): Number of channels. (Default: 256)
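
    Example:
        >>> # A minimal sketch; `waveform` is assumed to be scaled to [-1, 1].
        >>> x_mu = transforms.MuLawEncoding(quantization_channels=256)(waveform)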
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded

        Returns:
            x_mu (torch.Tensor): An encoded signal
        """
        return F.mu_law_encoding(x, self.quantization_channels)


class MuLawDecoding(torch.jit.ScriptModule):
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels. (Default: 256)
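
    Example:
        >>> # A minimal sketch; `x_mu` is a mu-law encoded signal, e.g. from MuLawEncoding.
        >>> waveform = transforms.MuLawDecoding(quantization_channels=256)(x_mu)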
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

        Returns:
            torch.Tensor: The decoded signal
        """
        return F.mu_law_decoding(x_mu, self.quantization_channels)


class Resample(torch.nn.Module):
    r"""Resamples a signal from one frequency to another. A resampling method can
    be given.

    Args:
        orig_freq (float): The original frequency of the signal.
        new_freq (float): The desired frequency.
        resampling_method (str): The resampling method. (Default: 'sinc_interpolation')
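
    Example:
        >>> # A minimal usage sketch; 'test.wav' is a placeholder path.
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> resampled = transforms.Resample(orig_freq=sample_rate, new_freq=16000.)(waveform)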
    """
    def __init__(self, orig_freq, new_freq, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
        self.resampling_method = resampling_method

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): The input signal of size (c, n)

        Returns:
            torch.Tensor: Output signal of size (c, m)
        """
        if self.resampling_method == 'sinc_interpolation':
            return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

        raise ValueError('Invalid resampling method: %s' % (self.resampling_method))