from __future__ import division, print_function
import torch
from torch.autograd import Variable
import numpy as np
try:
    import librosa
except ImportError:
    librosa = None


def _check_is_variable(tensor):
    if isinstance(tensor, torch.Tensor):
        is_variable = False
        tensor = Variable(tensor, requires_grad=False)
    elif isinstance(tensor, Variable):
        is_variable = True
    else:
        raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))

    return tensor, is_variable


def _tlog10(x):
    """Pytorch Log10
    """
    return torch.log(x) / torch.log(x.new([10]))
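
# Sanity sketch (hypothetical values): since log10(x) = ln(x) / ln(10),
# _tlog10(torch.Tensor([1., 10., 100.])) should give roughly [0., 1., 2.].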


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.Scale(),
        >>>     transforms.PadTrim(max_len=16000),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio):
        for t in self.transforms:
            audio = t(audio)
        return audio

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
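
# Example sketch (Compose): composing the two transforms from the class
# docstring; printing the pipeline uses the __repr__ above:
#   >>> t = Compose([Scale(), PadTrim(max_len=16000)])
#   >>> print(t)
#   Compose(
#       Scale()
#       PadTrim(max_len=16000)
#   )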


class Scale(object):
    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
    to a floating point number between -1.0 and 1.0.  Note the 16-bit number is
    called the "bit depth" or "precision", not to be confused with "bit rate".

    Args:
        factor (int): maximum value of input tensor. default: 2**31 (32-bit depth)

    """

    def __init__(self, factor=2**31):
        self.factor = factor

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)

        """
        if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
            tensor = tensor.float()

        return tensor / self.factor

    def __repr__(self):
        return self.__class__.__name__ + '()'
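
# Example sketch (Scale): with the default factor of 2**31, 32-bit integer
# samples map into roughly [-1.0, 1.0), e.g.
#   >>> Scale()(torch.IntTensor([2**30, -2**30]))   # -> [0.5, -0.5]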


class PadTrim(object):
    """Pad/Trim a 1d-Tensor (Signal or Labels)

    Args:
        max_len (int): length to which the tensor will be padded or trimmed
        fill_value (float): value used for padding. default: 0

    """

    def __init__(self, max_len, fill_value=0):
        self.max_len = max_len
        self.fill_value = fill_value

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: (max_len x Channels)

        """
        if self.max_len > tensor.size(0):
            pad = torch.ones((self.max_len - tensor.size(0),
                              tensor.size(1))) * self.fill_value
            pad = pad.type_as(tensor)
            tensor = torch.cat((tensor, pad), dim=0)
        elif self.max_len < tensor.size(0):
            tensor = tensor[:self.max_len, :]
        return tensor

    def __repr__(self):
        return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)
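
# Example sketch (PadTrim): a short signal is padded with fill_value rows up to
# max_len, while a longer one is trimmed to the first max_len samples:
#   >>> PadTrim(max_len=5)(torch.zeros(3, 1)).size()   # -> (5, 1)
#   >>> PadTrim(max_len=5)(torch.zeros(8, 1)).size()   # -> (5, 1)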


class DownmixMono(object):
    """Downmix any stereo signals to mono

    Inputs:
        tensor (Tensor): Tensor of audio of size (Samples x Channels)

    Returns:
        tensor (Tensor): Tensor of size (Samples x 1)

    """

    def __call__(self, tensor):
        if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
            tensor = tensor.float()

        if tensor.size(1) > 1:
            tensor = torch.mean(tensor, 1, True)  # already float here; avoid a redundant cast
        return tensor

    def __repr__(self):
        return self.__class__.__name__ + '()'
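
# Example sketch (DownmixMono): stereo input is averaged across channels,
# keeping a trailing channel dimension of 1:
#   >>> DownmixMono()(torch.ones(100, 2)).size()   # -> (100, 1)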


class LC2CL(object):
    """Permute a 2d tensor from samples (Length) x Channels to Channels x
       samples (Length)
    """

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio signal with shape (LxC)

        Returns:
            tensor (Tensor): Tensor of audio signal with shape (CxL)

        """

        return tensor.transpose(0, 1).contiguous()

    def __repr__(self):
        return self.__class__.__name__ + '()'
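
# Example sketch (LC2CL): swaps the two axes into the channels-first layout
# that the stft-based transforms below expect:
#   >>> LC2CL()(torch.zeros(16000, 2)).size()   # -> (2, 16000)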


class SPECTROGRAM(object):
    """Create a spectrogram from a raw audio signal

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
        pad (int): two sided padding of signal
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function

    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, window=torch.hann_window, wkwargs=None):
        if isinstance(window, Variable):
            self.window = window
        else:
            self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
            self.window = Variable(self.window, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        # size of the fft. the returned STFT result will have n_fft // 2 + 1
        # frequencies due to onesided=True in torch.stft
        self.n_fft = (n_fft - 1) * 2 if n_fft is not None else ws
        self.pad = pad
        self.wkwargs = wkwargs

    def __call__(self, sig):
        """
        Args:
            sig (Tensor or Variable): Tensor of audio of size (c, n)

        Returns:
            spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels
                is unchanged, hops is the number of hops, and n_fft is the
                number of Fourier bins, which is the window size divided
                by 2 plus 1 by default.

        """
        sig, is_variable = _check_is_variable(sig)

        assert sig.dim() == 2

        if self.pad > 0:
            c, n = sig.size()
            new_sig = sig.new_empty(c, n + self.pad * 2)
            new_sig[:, :self.pad].zero_()
            new_sig[:, -self.pad:].zero_()
            new_sig.narrow(1, self.pad, n).copy_(sig)
            sig = new_sig

        spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
                            self.window, center=False,
                            normalized=True, onesided=True).transpose(1, 2)
        spec_f /= self.window.pow(2).sum().sqrt()
        spec_f = spec_f.pow(2).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
        return spec_f if is_variable else spec_f.data
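
# Example sketch (SPECTROGRAM): for a hypothetical 1-second mono clip with the
# defaults ws=400 and hop=200, center=False gives
# (16000 - 400) // 200 + 1 = 79 hops and 400 // 2 + 1 = 201 frequency bins:
#   >>> spec = SPECTROGRAM()(torch.randn(1, 16000))   # -> (1, 79, 201)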


class F2M(object):
    """This turns a normal STFT into a MEL Frequency STFT, using a conversion
       matrix.  This uses triangular filter banks.

    Args:
        n_mels (int): number of MEL bins
        sr (int): sample rate of audio signal
        f_max (float, optional): maximum frequency. default: sr // 2
        f_min (float): minimum frequency. default: 0
    """
    def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0.):
        self.n_mels = n_mels
        self.sr = sr
        self.f_max = f_max if f_max is not None else sr // 2
        self.f_min = f_min

    def __call__(self, spec_f):

        spec_f, is_variable = _check_is_variable(spec_f)
        n_fft = spec_f.size(2)

        m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
        m_max = 2595 * np.log10(1. + (self.f_max / 700))

        m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
        f_pts = (700 * (10**(m_pts / 2595) - 1))

        bins = torch.floor(((n_fft - 1) * 2) * f_pts / self.sr).long()

        fb = torch.zeros(n_fft, self.n_mels)
        for m in range(1, self.n_mels + 1):
            f_m_minus = bins[m - 1].item()
            f_m = bins[m].item()
            f_m_plus = bins[m + 1].item()

            if f_m_minus != f_m:
                fb[f_m_minus:f_m, m - 1] = (torch.arange(f_m_minus, f_m) - f_m_minus) / (f_m - f_m_minus)
            if f_m != f_m_plus:
                fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)

        fb = Variable(fb)
        spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
        return spec_m if is_variable else spec_m.data
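
# Example sketch (F2M): the filter bank maps linear frequency bins onto the mel
# scale m = 2595 * log10(1 + f / 700), so the hypothetical (1, 79, 201) power
# spectrogram above would become (1, 79, 40) with the default n_mels=40:
#   >>> spec_m = F2M()(spec)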


class SPEC2DB(object):
    """Turns a spectrogram from the power/amplitude scale to the decibel scale.

    Args:
        stype (str): scale of input spectrogram ("power" or "magnitude"); power
            is the elementwise square of the magnitude. default: "power"
        top_db (float, optional): minimum negative cut-off in decibels.  A reasonable number
            is -80.
    """
    def __init__(self, stype="power", top_db=None):
        self.stype = stype
        # top_db may be None (the default); a positive value is negated
        self.top_db = -top_db if top_db is not None and top_db > 0 else top_db
        self.multiplier = 10. if stype == "power" else 20.

    def __call__(self, spec):

        spec, is_variable = _check_is_variable(spec)
        spec_db = self.multiplier * _tlog10(spec / spec.max())  # power -> dB
        if self.top_db is not None:
            spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
        return spec_db if is_variable else spec_db.data
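
# Example sketch (SPEC2DB): for a "power" input this computes
# 10 * log10(S / S.max()), so the peak maps to 0 dB and top_db=-80 floors
# everything below -80 dB:
#   >>> spec_db = SPEC2DB("power", top_db=-80)(spec_m)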


class MEL2(object):
    """Create MEL Spectrograms from a raw audio signal using the stft
       function in PyTorch.  Hopefully this solves the speed issue of using
       librosa.

    Sources:
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sr (int): sample rate of audio signal
        ws (int): window size, often called the fft size as well
        hop (int, optional): length of hop between STFT windows. default: ws // 2
        n_fft (int, optional): number of fft bins. default: ws // 2 + 1
        pad (int): two sided padding of signal
        n_mels (int): number of MEL bins
        window (torch windowing function): default: torch.hann_window
        wkwargs (dict, optional): arguments for window function

    Example:
        >>> sig, sr = torchaudio.load("test.wav", normalization=True)
        >>> sig = transforms.LC2CL()(sig)  # (n, c) -> (c, n)
        >>> spec_mel = transforms.MEL2(sr)(sig)  # (c, l, m)
    """
    def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
                 pad=0, n_mels=40, window=torch.hann_window, wkwargs=None):
        self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
        self.window = Variable(self.window, requires_grad=False)
        self.sr = sr
        self.ws = ws
        self.hop = hop if hop is not None else ws // 2
        self.n_fft = n_fft  # number of fourier bins (ws // 2 + 1 by default)
        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.wkwargs = wkwargs
        self.top_db = -80.
        self.f_max = None
        self.f_min = 0.

    def __call__(self, sig):
        """
        Args:
            sig (Tensor): Tensor of audio of size (channels [c], samples [n])

        Returns:
            spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
                is unchanged, hops is the number of hops, and n_mels is the
                number of mel bins.

        """

        sig, is_variable = _check_is_variable(sig)

        transforms = Compose([
            SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
                        self.pad, self.window),
            F2M(self.n_mels, self.sr, self.f_max, self.f_min),
            SPEC2DB("power", self.top_db),
        ])

        spec_mel_db = transforms(sig)

        return spec_mel_db if is_variable else spec_mel_db.data
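
# Example sketch (MEL2): equivalent to running the Compose pipeline above by
# hand; with the defaults and a hypothetical (1, 16000) clip at 16 kHz:
#   >>> mel_db = MEL2()(torch.randn(1, 16000))   # -> (1, 79, 40), in dB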


class MEL(object):
    """Create MEL Spectrograms from a raw audio signal. Relatively pretty slow.

       Usage (see librosa.feature.melspectrogram docs):
           MEL(sr=16000, n_fft=1600, hop_length=800, n_mels=64)
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of audio of size (samples [n] x channels [c])

        Returns:
            tensor (Tensor): n_mels x hops x channels (BxLxC), where n_mels is
                the number of mel bins, hops is the number of hops, and channels
                is unchanged.

        """

        if librosa is None:
            print("librosa not installed, cannot create spectrograms")
            return tensor
        L = []
        for i in range(tensor.size(1)):
            nparr = tensor[:, i].numpy()  # (samples, )
            sgram = librosa.feature.melspectrogram(
                y=nparr, **self.kwargs)  # (n_mels, hops)
            L.append(sgram)
        L = np.stack(L, 2)  # (n_mels, hops, channels)
        tensor = torch.from_numpy(L).type_as(tensor)

        return tensor

    def __repr__(self):
        return self.__class__.__name__ + '()'
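
# Example sketch (MEL, requires librosa): input is (samples x channels), output
# is (n_mels x hops x channels); chain with BLC2CBL below to get (c, b, l):
#   >>> sig = torch.randn(16000, 1)            # hypothetical mono clip
#   >>> m = MEL(sr=16000, n_mels=40)(sig)      # -> (40, hops, 1)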


class BLC2CBL(object):
    """Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
       Bands x samples (Length)
    """

    def __call__(self, tensor):
        """

        Args:
            tensor (Tensor): Tensor of spectrogram with shape (BxLxC)

        Returns:
            tensor (Tensor): Tensor of spectrogram with shape (CxBxL)

        """

        return tensor.permute(2, 0, 1).contiguous()

    def __repr__(self):
        return self.__class__.__name__ + '()'
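
# Example sketch (BLC2CBL): reorders a spectrogram into the channels-first
# layout that torch conv layers expect:
#   >>> BLC2CBL()(torch.zeros(40, 79, 1)).size()   # -> (1, 40, 79)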


class MuLawEncoding(object):
    """Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int): Number of channels. default: 256

    """

    def __init__(self, quantization_channels=256):
        self.qc = quantization_channels

    def __call__(self, x):
        """

        Args:
            x (FloatTensor/LongTensor or ndarray)

        Returns:
            x_mu (LongTensor or ndarray)

        """
        mu = self.qc - 1.
        if isinstance(x, np.ndarray):
            x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
            x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
        elif isinstance(x, (torch.Tensor, torch.LongTensor)):
            if isinstance(x, torch.LongTensor):
                x = x.float()
            mu = torch.FloatTensor([mu])
            x_mu = torch.sign(x) * torch.log1p(mu *
                                               torch.abs(x)) / torch.log1p(mu)
            x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
        return x_mu

    def __repr__(self):
        return self.__class__.__name__ + '()'
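
# Worked sketch (MuLawEncoding): with the default 256 channels, mu = 255 and an
# input of 0 encodes to ((0 + 1) / 2 * 255 + 0.5) -> 128, the midpoint:
#   >>> MuLawEncoding()(torch.zeros(4))   # -> [128, 128, 128, 128]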


class MuLawExpanding(object):
    """Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels. default: 256

    """

    def __init__(self, quantization_channels=256):
        self.qc = quantization_channels

    def __call__(self, x_mu):
        """

        Args:
            x_mu (FloatTensor/LongTensor or ndarray)

        Returns:
            x (FloatTensor or ndarray)

        """
        mu = self.qc - 1.
        if isinstance(x_mu, np.ndarray):
            x = (x_mu / mu) * 2 - 1.
            x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
        elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)):
            if isinstance(x_mu, torch.LongTensor):
                x_mu = x_mu.float()
            mu = torch.FloatTensor([mu])
            x = (x_mu / mu) * 2 - 1.
            x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
        return x

    def __repr__(self):
        return self.__class__.__name__ + '()'
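
# Round-trip sketch: expanding an encoded signal approximately recovers the
# original; mu-law quantization is lossy, so expect small errors:
#   >>> x = torch.rand(100) * 2 - 1                  # scaled to [-1, 1)
#   >>> x_hat = MuLawExpanding()(MuLawEncoding()(x))
#   >>> (x - x_hat).abs().max()                      # small quantization error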