from __future__ import division, print_function
from warnings import warn
import torch
import numpy as np


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.Scale(),
        >>>     transforms.PadTrim(max_len=16000),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio):
        for t in self.transforms:
            audio = t(audio)
        return audio

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string


class Scale(object):
    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
    to a floating point number between -1.0 and 1.0.  Note the 16-bit number is
    called the "bit depth" or "precision", not to be confused with "bit rate".

    Args:
        factor (int): maximum value of input tensor. default: 2**31 (32-bit depth)
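
    Example (illustrative; assumes `sig` is a 16-bit integer-valued tensor):
        >>> scaled = transforms.Scale(factor=2**15)(sig)  # values now in [-1.0, 1.0]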

    """

    def __init__(self, factor=2**31):
        self.factor = factor

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)

        """
        if not tensor.dtype.is_floating_point:
            tensor = tensor.to(torch.float32)

        return tensor / self.factor

    def __repr__(self):
        return self.__class__.__name__ + '()'


class PadTrim(object):
    """Pad/Trim a 1d-Tensor (Signal or Labels)

    Args:
        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
        max_len (int): Length to which the tensor will be padded
        fill_value (float): Value with which the tensor is padded. Default: `0`
        channels_first (bool): Pad for channels first tensors.  Default: `True`

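    Example (illustrative; assumes `sig` is a (1 x n) float tensor):
        >>> padded = transforms.PadTrim(max_len=16000)(sig)  # (1 x 16000)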
    """

    def __init__(self, max_len, fill_value=0, channels_first=True):
        self.max_len = max_len
        self.fill_value = fill_value
        self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)

    def __call__(self, tensor):
        """

        Returns:
            Tensor: (c x n) or (n x c)

        """
        assert tensor.size(self.ch_dim) < 128, \
            "Too many channels ({}) detected, see channels_first param.".format(tensor.size(self.ch_dim))
        if self.max_len > tensor.size(self.len_dim):
            padding = [self.max_len - tensor.size(self.len_dim)
                       if (i % 2 == 1) and (i // 2 != self.len_dim)
                       else 0
                       for i in range(4)]
            with torch.no_grad():
                tensor = torch.nn.functional.pad(tensor, padding, "constant", self.fill_value)
        elif self.max_len < tensor.size(self.len_dim):
            tensor = tensor.narrow(self.len_dim, 0, self.max_len)
        return tensor

    def __repr__(self):
        return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)


class DownmixMono(object):
    """Downmix any stereo signals to mono.  Consider using a `SoxEffectsChain` with
       the `channels` effect instead of this transformation.

    Args:
        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
        channels_first (bool): Downmix across channels dimension.  Default: `True`

    Returns:
        tensor (Tensor): Downmixed audio of size (1 x n) or (n x 1)

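    Example (illustrative; assumes `sig` is a (2 x n) stereo float tensor):
        >>> mono = transforms.DownmixMono(channels_first=True)(sig)  # (1 x n)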
    """

    def __init__(self, channels_first=True):
        self.ch_dim = int(not channels_first)

    def __call__(self, tensor):
        if not tensor.dtype.is_floating_point:
            tensor = tensor.to(torch.float32)

        tensor = torch.mean(tensor, self.ch_dim, True)
        return tensor

    def __repr__(self):
        return self.__class__.__name__ + '()'


class LC2CL(object):
    """Permute a 2d tensor from samples (n x c) to (c x n)
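
    Example (illustrative; assumes `sig` is an (n x c) tensor):
        >>> sig = transforms.LC2CL()(sig)  # (c x n)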
    """

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio signal with shape (LxC)

        Returns:
            tensor (Tensor): Tensor of audio signal with shape (CxL)

        """
        return tensor.transpose(0, 1).contiguous()

    def __repr__(self):
        return self.__class__.__name__ + '()'


def SPECTROGRAM(*args, **kwargs):
    warn("SPECTROGRAM has been renamed to Spectrogram")
    return Spectrogram(*args, **kwargs)


class Spectrogram(object):
    """Create a spectrogram from a raw audio signal

    Args:
        ws (int): window size
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins. default: ws
        pad (int): two sided padding of signal
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function
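
    Example (illustrative; assumes `sig` is a (c, n) float tensor):
        >>> spec = transforms.Spectrogram(ws=400)(sig)  # (c, l, 201)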

    """
    def __init__(self, ws=400, hop=None, n_fft=None,
                 pad=0, window=torch.hann_window, wkwargs=None):
        self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        # number of fft bins. the returned STFT result will have n_fft // 2 + 1
        # number of frequencies due to onesided=True in torch.stft
        self.n_fft = n_fft if n_fft is not None else ws
        self.pad = pad
        self.wkwargs = wkwargs

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and the last
                dimension holds the n_fft // 2 + 1 Fourier bins (the window
                size divided by 2 plus 1 when `n_fft` defaults to `ws`).

        """
        assert sig.dim() == 2

        if self.pad > 0:
            with torch.no_grad():
                sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
        self.window = self.window.to(sig.device)
        spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
                            self.window, center=False,
                            normalized=True, onesided=True).transpose(1, 2)
        spec_f /= self.window.pow(2).sum().sqrt()
        spec_f = spec_f.pow(2).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
        return spec_f


def F2M(*args, **kwargs):
    warn("F2M has been renamed to MelScale")
    return MelScale(*args, **kwargs)


class MelScale(object):
    """This turns a normal STFT into a mel frequency STFT, using a conversion
       matrix.  This uses triangular filter banks.

    Args:
        n_mels (int): number of mel bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: `sr` // 2
        f_min (float): minimum frequency. default: 0
        n_stft (int, optional): number of filter banks from stft. Calculated from first input
            if `None` is given.  See `n_fft` in `Spectrogram`.
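
    Example (illustrative; assumes `spec` is a (c, l, n_stft) power spectrogram):
        >>> spec_mel = transforms.MelScale(n_mels=40)(spec)  # (c, l, 40)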
    """
    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
        self.n_mels = n_mels
        self.sr = sr
        self.f_max = f_max if f_max is not None else sr // 2
        self.f_min = f_min
        self.fb = self._create_fb_matrix(n_stft) if n_stft is not None else None

    def __call__(self, spec_f):
        if self.fb is None:
            self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
        spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m

    def _create_fb_matrix(self, n_stft):
        """ Create a frequency bin conversion matrix.

        Args:
            n_stft (int): number of filter banks from spectrogram
        """

        # get stft freq bins
        stft_freqs = torch.linspace(self.f_min, self.f_max, n_stft)
        # calculate mel freq bins
        m_min = 0. if self.f_min == 0 else self._hertz_to_mel(self.f_min)
        m_max = self._hertz_to_mel(self.f_max)
        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        f_pts = self._mel_to_hertz(m_pts)
        # calculate the difference between each mel point and each stft freq point in hertz
        f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
        slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
        # create overlapping triangles
        z = torch.tensor(0.)
        down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
        up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
        fb = torch.max(z, torch.min(down_slopes, up_slopes))
        return fb

    def _hertz_to_mel(self, f):
        return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))

    def _mel_to_hertz(self, mel):
        return 700. * (10**(mel / 2595.) - 1.)


def SPEC2DB(*args, **kwargs):
    warn("SPEC2DB has been renamed to SpectogramToDB, please update your program")
    return SpectogramToDB(*args, **kwargs)


class SpectogramToDB(object):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude").  The
            power being the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is -80.
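
    Example (illustrative; assumes `spec` is a power spectrogram tensor):
        >>> spec_db = transforms.SpectogramToDB(stype="power", top_db=-80.)(spec)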
    """
    def __init__(self, stype="power", top_db=None):
        self.stype = stype
        if top_db is not None and top_db > 0:
            top_db = -top_db
        self.top_db = top_db
        self.multiplier = 10. if stype == "power" else 20.

    def __call__(self, spec):

        spec_db = self.multiplier * torch.log10(spec / spec.max())  # power -> dB
        if self.top_db is not None:
            spec_db = torch.max(spec_db, spec_db.new_full((1,), self.top_db))
        return spec_db


def MEL2(*args, **kwargs):
    warn("MEL2 has been renamed to MelSpectrogram")
    return MelSpectrogram(*args, **kwargs)


class MelSpectrogram(object):
    """Create MEL Spectrograms from a raw audio signal using the stft
       function in PyTorch.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size
        hop (int, optional): length of hop between STFT windows. default: `ws` // 2
        n_fft (int, optional): size of fft, creates `n_fft` // 2 + 1 bins. default: `ws`
        f_max (float, optional): maximum frequency. default: `sr` // 2
        f_min (float): minimum frequency. default: 0
        pad (int): two sided padding of signal
        n_mels (int): number of MEL bins
        window (torch windowing function): default: `torch.hann_window`
        wkwargs (dict, optional): arguments for window function

    Example:
        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
        >>> spec_mel = transforms.MelSpectrogram(sr)(sig)  # (c, l, m)
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, f_min=0., f_max=None,
                 pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
        self.window = window
        self.sr = sr
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        self.n_fft = n_fft  # size of fft (defaults to `ws` inside Spectrogram)
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.wkwargs = wkwargs
        self.top_db = -80.
        self.f_max = f_max
        self.f_min = f_min
        self.spec = Spectrogram(self.ws, self.hop, self.n_fft,
                                self.pad, self.window, self.wkwargs)
        self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
        self.s2db = SpectogramToDB("power", self.top_db)
        self.transforms = Compose([
            self.spec, self.fm, self.s2db,
        ])

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
                is unchanged, hops is the number of hops, and n_mels is the
                number of mel bins.

        """
        spec_mel_db = self.transforms(sig)

        return spec_mel_db


def MEL(*args, **kwargs):
    raise DeprecationWarning("MEL has been removed from the library, please use MelSpectrogram or librosa")


class BLC2CBL(object):
    """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
       Bands x Samples length
374
375
376
377
378
379
380
381
382
383
384
385
386
387
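    Example (illustrative; assumes `spec` is a (b, l, c) spectrogram tensor):
        >>> spec = transforms.BLC2CBL()(spec)  # (c, b, l)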
    """

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)

        Returns:
            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)

        """

        return tensor.permute(2, 0, 1).contiguous()

    def __repr__(self):
        return self.__class__.__name__ + '()'


class MuLawEncoding(object):
    """Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int): Number of channels. default: 256
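
    Example (illustrative; assumes `sig` is a float tensor scaled to [-1., 1.]):
        >>> sig_mu = transforms.MuLawEncoding(quantization_channels=256)(sig)  # values in [0, 255]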

    """

    def __init__(self, quantization_channels=256):
        self.qc = quantization_channels

    def __call__(self, x):
        """

        Args:
            x (FloatTensor/LongTensor or ndarray)

        Returns:
            x_mu (LongTensor or ndarray)

        """
        mu = self.qc - 1.
        if isinstance(x, np.ndarray):
            x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
            x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
        elif isinstance(x, torch.Tensor):
            if not x.dtype.is_floating_point:
                x = x.to(torch.float)
            mu = torch.tensor(mu, dtype=x.dtype)
            x_mu = torch.sign(x) * torch.log1p(mu *
                                               torch.abs(x)) / torch.log1p(mu)
            x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
        return x_mu

    def __repr__(self):
        return self.__class__.__name__ + '()'


class MuLawExpanding(object):
    """Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels. default: 256
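
    Example (illustrative; assumes `sig_mu` was produced by `MuLawEncoding`):
        >>> sig = transforms.MuLawExpanding(quantization_channels=256)(sig_mu)  # values in [-1., 1.]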

    """

    def __init__(self, quantization_channels=256):
        self.qc = quantization_channels

    def __call__(self, x_mu):
        """

        Args:
            x_mu (FloatTensor/LongTensor or ndarray)

        Returns:
            x (FloatTensor or ndarray)

        """
        mu = self.qc - 1.
        if isinstance(x_mu, np.ndarray):
            x = ((x_mu) / mu) * 2 - 1.
            x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
        elif isinstance(x_mu, torch.Tensor):
            if not x_mu.dtype.is_floating_point:
                x_mu = x_mu.to(torch.float)
            mu = torch.tensor(mu, dtype=x_mu.dtype)
            x = ((x_mu) / mu) * 2 - 1.
            x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
        return x

    def __repr__(self):
        return self.__class__.__name__ + '()'