from __future__ import absolute_import, division, print_function, unicode_literals

import math

import torch

__all__ = [
    "istft",
    "spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "create_fb_matrix",
    "create_dct",
    "mu_law_encoding",
    "mu_law_decoding",
    "complex_norm",
    "angle",
    "magphase",
    "phase_vocoder",
    "lfilter",
    "lowpass_biquad",
    "highpass_biquad",
    "equalizer_biquad",
    "biquad",
    "mask_along_axis",
    "mask_along_axis_iid",
]


# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore
def _stft(
    waveform,
    n_fft,
    hop_length,
    win_length,
    window,
    center,
    pad_mode,
    normalized,
    onesided,
):
    # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
    return torch.stft(
        waveform,
        n_fft,
        hop_length,
        win_length,
        window,
        center,
        pad_mode,
        normalized,
        onesided,
    )


def istft(
    stft_matrix,  # type: Tensor
    n_fft,  # type: int
    hop_length=None,  # type: Optional[int]
    win_length=None,  # type: Optional[int]
    window=None,  # type: Optional[Tensor]
    center=True,  # type: bool
    pad_mode="reflect",  # type: str
    normalized=False,  # type: bool
    onesided=True,  # type: bool
    length=None,  # type: Optional[int]
):
    # type: (...) -> Tensor
    r"""Inverse short time Fourier Transform. This is expected to be the inverse of ``torch.stft``.
    It has the same parameters (plus an additional optional parameter, ``length``) and it should return the
    least squares estimation of the original signal. The algorithm will check using the NOLA condition
    (nonzero overlap).

    An important consideration for the parameters ``window`` and ``center`` is that the envelope
    created by the summation of all the windows must never be zero at any point in time. Specifically,
    :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \neq 0`.

    Since stft discards elements at the end of the signal if they do not fit in a frame, istft
    may return a shorter signal than the original signal (this can occur if ``center`` is False,
    since the signal is not padded).

    If ``center`` is True, then there will be padding, e.g. 'constant', 'reflect', etc. Left padding
    can be trimmed off exactly because it can be calculated, but right padding cannot be calculated
    without additional information.

    Example: Suppose the last window is:
    [17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]

    The n_frame, hop_length, and win_length are all the same, which prevents the calculation of right
    padding. These additional values could be zeros or a reflection of the signal, so providing
    ``length`` could be useful. If ``length`` is ``None`` then padding will be aggressively removed
    (some loss of signal).

    [1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
    IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.

    Args:
        stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
            column is a window. It has a size of (..., fft_size, n_frame, 2)
        n_fft (int): Size of Fourier transform
        hop_length (Optional[int]): The distance between neighboring sliding window frames.
            (Default: ``win_length // 4``)
        win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
        window (Optional[torch.Tensor]): The optional window function.
            (Default: ``torch.ones(win_length)``)
        center (bool): Whether ``input`` was padded on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (str): Controls the padding method used when ``center`` is True. (Default:
            ``'reflect'``)
        normalized (bool): Whether the STFT was normalized. (Default: ``False``)
        onesided (bool): Whether the STFT is onesided. (Default: ``True``)
        length (Optional[int]): The amount to trim the signal by (i.e. the
            original signal length). (Default: whole signal)

    Returns:
        torch.Tensor: Least squares estimation of the original signal of size (..., signal_length)
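
    Example
        >>> # A round trip, as an illustrative sketch (shapes and values are assumptions):
        >>> waveform = torch.randn(2, 16000)  # (channel, time)
        >>> spec = torch.stft(waveform, n_fft=400)  # (channel, fft_size, n_frame, 2)
        >>> restored = istft(spec, n_fft=400, length=waveform.size(-1))
        >>> restored.shape
        torch.Size([2, 16000])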
    """
    stft_matrix_dim = stft_matrix.dim()
    assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
    assert stft_matrix.numel() > 0

    if stft_matrix_dim == 3:
        # add a channel dimension
        stft_matrix = stft_matrix.unsqueeze(0)

    # pack batch
    shape = stft_matrix.size()
    stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])

    dtype = stft_matrix.dtype
    device = stft_matrix.device
    fft_size = stft_matrix.size(1)
    assert (onesided and n_fft // 2 + 1 == fft_size) or (
        not onesided and n_fft == fft_size
    ), (
        "onesided implies that n_fft // 2 + 1 == fft_size and not onesided implies n_fft == fft_size. "
        + "Given values were onesided: %s, n_fft: %d, fft_size: %d"
        % ("True" if onesided else "False", n_fft, fft_size)
    )

    # use stft defaults for Optionals
    if win_length is None:
        win_length = n_fft

    if hop_length is None:
        hop_length = int(win_length // 4)

    # There must be overlap
    assert 0 < hop_length <= win_length
    assert 0 < win_length <= n_fft

    if window is None:
        window = torch.ones(win_length, device=device, dtype=dtype)

    assert window.dim() == 1 and window.size(0) == win_length

    if win_length != n_fft:
        # center window with pad left and right zeros
        left = (n_fft - win_length) // 2
        window = torch.nn.functional.pad(window, (left, n_fft - win_length - left))
        assert window.size(0) == n_fft
    # win_length and n_fft are synonymous from here on

    stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frame, fft_size, 2)
    stft_matrix = torch.irfft(
        stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
    )  # size (channel, n_frame, n_fft)

    assert stft_matrix.size(2) == n_fft
    n_frame = stft_matrix.size(1)

    ytmp = stft_matrix * window.view(1, 1, n_fft)  # size (channel, n_frame, n_fft)
    # each column of a channel is a frame which needs to be overlap added at the right place
    ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frame)

    eye = torch.eye(n_fft)
    eye = eye.to(device=device, dtype=dtype).unsqueeze(1)  # size (n_fft, 1, n_fft)

    # this does overlap add where the frames of ytmp are added such that the i'th frame of
    # ytmp is added starting at i*hop_length in the output
    y = torch.nn.functional.conv_transpose1d(
        ytmp, eye, stride=hop_length, padding=0
    )  # size (channel, 1, expected_signal_len)

    # do the same for the window function
    window_sq = (
        window.pow(2).view(n_fft, 1).repeat((1, n_frame)).unsqueeze(0)
    )  # size (1, n_fft, n_frame)
    window_envelop = torch.nn.functional.conv_transpose1d(
        window_sq, eye, stride=hop_length, padding=0
    )  # size (1, 1, expected_signal_len)

    expected_signal_len = n_fft + hop_length * (n_frame - 1)
    assert y.size(2) == expected_signal_len
    assert window_envelop.size(2) == expected_signal_len

    half_n_fft = n_fft // 2
    # we need to trim the front padding away if center
    start = half_n_fft if center else 0
    end = -half_n_fft if length is None else start + length

    y = y[:, :, start:end]
    window_envelop = window_envelop[:, :, start:end]

    # check NOLA non-zero overlap condition
    window_envelop_lowest = window_envelop.abs().min()
    assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
        window_envelop_lowest
    )

    y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)

    # unpack batch
    y = y.reshape(shape[:-3] + y.shape[-1:])

    if stft_matrix_dim == 3:  # remove the channel dimension
        y = y.squeeze(0)
    return y


def spectrogram(
    waveform, pad, window, n_fft, hop_length, win_length, power, normalized
):
    # type: (Tensor, int, Tensor, int, int, int, Optional[int], bool) -> Tensor
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time)
        pad (int): Two sided padding of signal
        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (Optional[int]): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft

    Returns:
        torch.Tensor: Dimension (..., channel, freq, time), where channel
        is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
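
    Example
        >>> # Illustrative usage (parameter values are assumptions, not library defaults):
        >>> waveform = torch.randn(1, 16000)  # (channel, time)
        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> spec = spectrogram(waveform, 0, window, n_fft, n_fft // 2, n_fft, 2, False)
        >>> spec.shape  # (channel, n_fft // 2 + 1, n_frame)
        torch.Size([1, 201, 81])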
    """

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = _stft(
        waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])

    if normalized:
        spec_f /= window.pow(2).sum().sqrt()
    if power is not None:
        spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor

    return spec_f


def griffinlim(
    spectrogram, window, n_fft, hop_length, win_length, power, normalized, n_iter, momentum, length, rand_init
):
    # type: (Tensor, Tensor, int, int, int, int, bool, int, float, Optional[int], bool) -> Tensor
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    .. [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    .. [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    .. [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        spectrogram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (int): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
        length (Optional[int]): Array length of the expected output. (Default: ``None``)
        rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``)

    Returns:
        torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given.
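
    Example
        >>> # Illustrative sketch; parameter values are assumptions, not defaults from the source.
        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> specgram = spectrogram(torch.randn(1, 16000), 0, window, n_fft, n_fft // 2, n_fft, 2, False)
        >>> waveform = griffinlim(specgram, window, n_fft, n_fft // 2, n_fft, 2, False, 32, 0.99, None, True)
        >>> waveform.shape
        torch.Size([1, 16000])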
    """
    assert momentum < 1, 'momentum=%s >= 1 can be unstable' % momentum
    assert momentum >= 0, 'momentum=%s < 0' % momentum

    spectrogram = spectrogram.pow(1 / power)

    # randomly initialize the phase
    batch, freq, frames = spectrogram.size()
    if rand_init:
        angles = 2 * math.pi * torch.rand(batch, freq, frames)
    else:
        angles = torch.zeros(batch, freq, frames)
    angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
                  .to(dtype=spectrogram.dtype, device=spectrogram.device)
    spectrogram = spectrogram.unsqueeze(-1).expand_as(angles)

    # And initialize the previous iterate to 0
    rebuilt = torch.tensor(0.)

    for _ in range(n_iter):
        # Store the previous iterate
        tprev = rebuilt

        # Invert with our current estimate of the phases
        inverse = istft(spectrogram * angles,
                        n_fft=n_fft,
                        hop_length=hop_length,
                        win_length=win_length,
                        window=window,
                        length=length).float()

        # Rebuild the spectrogram
        rebuilt = _stft(inverse, n_fft, hop_length, win_length, window,
                        True, 'reflect', False, True)

        # Update our phase estimates
        angles = rebuilt - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles))

    # Return the final phase estimates
    return istft(spectrogram * angles,
                 n_fft=n_fft,
                 hop_length=hop_length,
                 win_length=win_length,
                 window=window,
                 length=length)


def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
    # type: (Tensor, float, float, float, Optional[float]) -> Tensor
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        x (torch.Tensor): Input tensor before being converted to decibel scale
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        torch.Tensor: Output tensor in decibel scale
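
    Example
        >>> # Power spectrogram to dB; the reference and parameter choices here are assumptions.
        >>> specgram = torch.rand(1, 201, 100) ** 2
        >>> db = amplitude_to_DB(specgram, 10.0, 1e-10, 0.0, top_db=80.0)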
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        x_db = x_db.clamp(min=x_db.max().item() - top_db)

    return x_db


def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
    # type: (int, float, float, int, int) -> Tensor
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform

    Returns:
        torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``),
        i.e. the number of frequencies to highlight/apply by the number of filterbanks.
        Each column is a filterbank so that, assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``torch.matmul(A, create_fb_matrix(A.size(-1), ...))``.
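
    Example
        >>> # Applying a mel filterbank to a power spectrogram (shapes are assumptions):
        >>> specgram = torch.rand(1, 201, 100)  # (channel, n_freqs, time)
        >>> fb = create_fb_matrix(201, 0., 8000., 128, 16000)
        >>> mel_specgram = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)
        >>> mel_specgram.shape
        torch.Size([1, 128, 100])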
    """
    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
    i_freqs = all_freqs.ge(f_min) & all_freqs.le(f_max)
    freqs = all_freqs[i_freqs]

    # calculate mel freq bins
    # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
    m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
    m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
    f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
    return fb


def create_dct(n_mfcc, n_mels, norm):
    # type: (int, int, Optional[str]) -> Tensor
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (Optional[str]): Norm to use (either 'ortho' or None)

    Returns:
        torch.Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def mu_law_encoding(x, quantization_channels):
    # type: (Tensor, int) -> Tensor
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (torch.Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        torch.Tensor: Input after mu-law encoding
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(x_mu, quantization_channels):
    # type: (Tensor, int) -> Tensor
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (torch.Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        torch.Tensor: Input after mu-law decoding
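
    Example
        >>> # Round trip through mu-law companding (values are illustrative):
        >>> x = torch.linspace(-1, 1, steps=5)
        >>> x_mu = mu_law_encoding(x, 256)
        >>> x_hat = mu_law_decoding(x_mu, 256)
        >>> torch.allclose(x, x_hat, atol=1e-2)
        True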
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = (x_mu / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x


def complex_norm(complex_tensor, power=1.0):
    # type: (Tensor, float) -> Tensor
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        torch.Tensor: Power of the normed input tensor. Shape of `(..., )`
    """
    if power == 1.0:
        return torch.norm(complex_tensor, 2, -1)
    return torch.norm(complex_tensor, 2, -1).pow(power)


def angle(complex_tensor):
    # type: (Tensor) -> Tensor
    r"""Compute the angle of complex tensor input.

    Args:
        complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`

    Returns:
        torch.Tensor: Angle of a complex tensor. Shape of `(..., )`
    """
    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


def magphase(complex_tensor, power=1.0):
    # type: (Tensor, float) -> Tuple[Tensor, Tensor]
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex tensor
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase


def phase_vocoder(complex_specgrams, rate, phase_advance):
    # type: (Tensor, float, Tensor) -> Tensor
    r"""Given a STFT tensor, speed up in time without modifying pitch by a
    factor of ``rate``.

    Args:
        complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
        rate (float): Speed-up factor
        phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
            of (freq, 1)

    Returns:
        complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
        freq, ceil(time/rate), complex=2)`

    Example
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
        >>> rate = 1.3  # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape  # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """

    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))

    time_steps = torch.arange(0,
                              complex_specgrams.size(-2),
                              rate,
                              device=complex_specgrams.device,
                              dtype=complex_specgrams.dtype)

    alphas = time_steps % 1.0
    phase_0 = angle(complex_specgrams[..., :1, :])

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

    angle_0 = angle(complex_specgrams_0)
    angle_1 = angle(complex_specgrams_1)

    norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
    norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    real_stretch = mag * torch.cos(phase_acc)
    imag_stretch = mag * torch.sin(phase_acc)

    complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])

    return complex_specgrams_stretch


def lfilter(waveform, a_coeffs, b_coeffs):
    # type: (Tensor, Tensor, Tensor) -> Tensor
    r"""Perform an IIR filter by evaluating difference equation.

    Args:
        waveform (torch.Tensor): audio waveform of dimension of `(..., time)`.  Must be normalized to -1 to 1.
        a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
                                Lower delay coefficients are first, e.g. `[a0, a1, a2, ...]`.
                                Must be same size as b_coeffs (pad with 0's as necessary).
        b_coeffs (torch.Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
                                 Lower delay coefficients are first, e.g. `[b0, b1, b2, ...]`.
                                 Must be same size as a_coeffs (pad with 0's as necessary).

    Returns:
        output_waveform (torch.Tensor): Dimension of `(..., time)`.  Output will be clipped to -1 to 1.
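
    Example
        >>> # A one-pole smoothing filter, y[n] = 0.5 * x[n] + 0.5 * y[n-1];
        >>> # the coefficient values are illustrative assumptions.
        >>> waveform = torch.rand(2, 1000) * 2 - 1  # (channel, time), in [-1, 1]
        >>> filtered = lfilter(waveform, torch.tensor([1.0, -0.5]), torch.tensor([0.5, 0.0]))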
    """

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    assert a_coeffs.size(0) == b_coeffs.size(0)
    assert len(waveform.size()) == 2
    assert waveform.device == a_coeffs.device
    assert b_coeffs.device == a_coeffs.device

    device = waveform.device
    dtype = waveform.dtype
    n_channel, n_sample = waveform.size()
    n_order = a_coeffs.size(0)
    n_sample_padded = n_sample + n_order - 1
    assert n_order > 0

    # Pad the input and create output
    padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
    padded_waveform[:, (n_order - 1):] = waveform
    padded_output_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)

    # Set up the coefficients matrix
    # Flip coefficients' order
    a_coeffs_flipped = a_coeffs.flip(0)
    b_coeffs_flipped = b_coeffs.flip(0)

    # calculate windowed_input_signal in parallel
    # create indices of original with shape (n_channel, n_order, n_sample)
    window_idxs = torch.arange(n_sample, device=device).unsqueeze(0) + torch.arange(n_order, device=device).unsqueeze(1)
    window_idxs = window_idxs.repeat(n_channel, 1, 1)
    window_idxs += (torch.arange(n_channel, device=device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
    window_idxs = window_idxs.long()
    # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
    input_signal_windows = torch.matmul(b_coeffs_flipped, torch.take(padded_waveform, window_idxs))

    for i_sample, o0 in enumerate(input_signal_windows.t()):
        windowed_output_signal = padded_output_waveform[:, i_sample:(i_sample + n_order)]
        o0.sub_(torch.mv(windowed_output_signal, a_coeffs_flipped))
        o0.div_(a_coeffs[0])

        padded_output_waveform[:, i_sample + n_order - 1] = o0

    output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)

    # unpack batch
    output = output.reshape(shape[:-1] + output.shape[-1:])

    return output


def biquad(waveform, b0, b1, b2, a0, a1, a2):
    # type: (Tensor, float, float, float, float, float, float) -> Tensor
    r"""Perform a biquad filter of input tensor.  Initial conditions set to 0.
    https://en.wikipedia.org/wiki/Digital_biquad_filter

    Args:
        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
        b0 (float): numerator coefficient of current input, x[n]
        b1 (float): numerator coefficient of input one time step ago x[n-1]
        b2 (float): numerator coefficient of input two time steps ago x[n-2]
        a0 (float): denominator coefficient of current output y[n], typically 1
        a1 (float): denominator coefficient of output one time step ago y[n-1]
        a2 (float): denominator coefficient of output two time steps ago y[n-2]

    Returns:
        output_waveform (torch.Tensor): Dimension of `(channel, time)`
    """

    device = waveform.device
    dtype = waveform.dtype

    output_waveform = lfilter(
        waveform,
        torch.tensor([a0, a1, a2], dtype=dtype, device=device),
        torch.tensor([b0, b1, b2], dtype=dtype, device=device)
    )
    return output_waveform


def _dB2Linear(x):
    # type: (float) -> float
    return math.exp(x * math.log(10) / 20.0)


def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
    # type: (Tensor, int, float, float) -> Tensor
    r"""Design biquad highpass filter and perform filtering.  Similar to SoX implementation.

    Args:
        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        cutoff_freq (float): filter cutoff frequency
        Q (float): https://en.wikipedia.org/wiki/Q_factor

    Returns:
        output_waveform (torch.Tensor): Dimension of `(channel, time)`
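
    Example
        >>> # Attenuate content below ~300 Hz (parameter values are illustrative):
        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> filtered = highpass_biquad(waveform, 16000, 300.0)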
    """

    w0 = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(w0) / 2. / Q

    b0 = (1 + math.cos(w0)) / 2
    b1 = -1 - math.cos(w0)
    b2 = b0
    a0 = 1 + alpha
    a1 = -2 * math.cos(w0)
    a2 = 1 - alpha
    return biquad(waveform, b0, b1, b2, a0, a1, a2)


def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
    # type: (Tensor, int, float, float) -> Tensor
    r"""Design biquad lowpass filter and perform filtering.  Similar to SoX implementation.

    Args:
        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        cutoff_freq (float): filter cutoff frequency
        Q (float): https://en.wikipedia.org/wiki/Q_factor

    Returns:
        output_waveform (torch.Tensor): Dimension of `(channel, time)`
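
    Example
        >>> # Attenuate content above ~1 kHz (parameter values are illustrative):
        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> filtered = lowpass_biquad(waveform, 16000, 1000.0)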
    """

    w0 = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(w0) / 2 / Q

    b0 = (1 - math.cos(w0)) / 2
    b1 = 1 - math.cos(w0)
    b2 = b0
    a0 = 1 + alpha
    a1 = -2 * math.cos(w0)
    a2 = 1 - alpha
    return biquad(waveform, b0, b1, b2, a0, a1, a2)


def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
    # type: (Tensor, int, float, float, float) -> Tensor
    r"""Design biquad peaking equalizer filter and perform filtering.  Similar to SoX implementation.

    Args:
        waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        center_freq (float): filter's central frequency
        gain (float): desired gain at the boost (or attenuation) in dB
        Q (float): https://en.wikipedia.org/wiki/Q_factor

    Returns:
        output_waveform (torch.Tensor): Dimension of `(channel, time)`
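
    Example
        >>> # Boost ~1 kHz by 3 dB (parameter values are illustrative):
        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> boosted = equalizer_biquad(waveform, 16000, 1000.0, 3.0)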
    """
    w0 = 2 * math.pi * center_freq / sample_rate
    A = math.exp(gain / 40.0 * math.log(10))
    alpha = math.sin(w0) / 2 / Q

    b0 = 1 + alpha * A
    b1 = -2 * math.cos(w0)
    b2 = 1 - alpha * A
    a0 = 1 + alpha / A
    a1 = -2 * math.cos(w0)
    a2 = 1 - alpha / A
    return biquad(waveform, b0, b1, b2, a0, a1, a2)


def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
    # type: (Tensor, int, float, int) -> Tensor
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have different mask intervals, sampled independently per example and channel.

    Args:
        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

    Returns:
        torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
    """

    if axis != 2 and axis != 3:
        raise ValueError('Only Frequency and Time masking are supported')

    value = torch.rand(specgrams.shape[:2]) * mask_param
    min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = (min_value.long())[..., None, None].float()
    mask_end = (min_value.long() + value.long())[..., None, None].float()
    mask = torch.arange(0, specgrams.size(axis)).float()

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(specgram, mask_param, mask_value, axis):
    # type: (Tensor, int, float, int) -> Tensor
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram (channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

    Returns:
        torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
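
    Example
        >>> # Mask up to 20 consecutive frequency bins (values are illustrative):
        >>> specgram = torch.randn(1, 201, 100)
        >>> masked = mask_along_axis(specgram, 20, 0.0, 1)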
    """

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()

    assert mask_end - mask_start < mask_param
    if axis == 1:
        specgram[:, mask_start:mask_end] = mask_value
    elif axis == 2:
        specgram[:, :, mask_start:mask_end] = mask_value
    else:
        raise ValueError('Only Frequency and Time masking are supported')

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram


def compute_deltas(specgram, win_length=5, mode="replicate"):
    # type: (Tensor, int, str) -> Tensor
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. math::
        d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is (`win_length`-1)//2.

    Args:
        specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
        win_length (int): The window length used for computing delta
        mode (str): Mode parameter passed to padding

    Returns:
        deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = (
        torch
        .arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
        .repeat(specgram.shape[1], 1, 1)
    )

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def gain(waveform, gain_db=1.0):
    # type: (Tensor, float) -> Tensor
    r"""Apply amplification or attenuation to the whole waveform.

    Args:
       waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
       gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).

    Returns:
       torch.Tensor: the whole waveform amplified by gain_db.
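
    Example
        >>> # +6 dB roughly doubles the amplitude, since 10 ** (6 / 20) is ~1.995 (illustrative):
        >>> louder = gain(torch.tensor([0.1, -0.1]), 6.0)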
    """
    if gain_db == 0:
        return waveform

    ratio = 10 ** (gain_db / 20)

    return waveform * ratio


def _add_noise_shaping(dithered_waveform, waveform):
    r"""Noise shaping is calculated by error:
    error[n] = dithered[n] - original[n]
    noise_shaped_waveform[n] = dithered[n] + error[n-1]
    """
    wf_shape = waveform.size()
    waveform = waveform.reshape(-1, wf_shape[-1])

    dithered_shape = dithered_waveform.size()
    dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])

    error = dithered_waveform - waveform

    # add error[n-1] to dithered_waveform[n], so offset the error by 1 index
    for index in range(error.size()[0]):
        err = error[index]
        error_offset = torch.cat((torch.zeros(1), err))
        error[index] = error_offset[:waveform.size()[1]]

    noise_shaped = dithered_waveform + error
    return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])


def _apply_probability_distribution(waveform, density_function="TPDF"):
    # type: (Tensor, str) -> Tensor
    r"""Apply a probability distribution function on a waveform.

    Triangular probability density function (TPDF) dither noise has a
    triangular distribution; values in the center of the range have a higher
    probability of occurring.

    Rectangular probability density function (RPDF) dither noise has a
    uniform distribution; any value in the specified range has the same
    probability of occurring.

    Gaussian probability density function (GPDF) has a normal distribution.
    The relationship of probabilities of results follows a bell-shaped,
    or Gaussian curve, typical of dither generated by analog sources.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
        density_function (string): The density function of a
           continuous random variable (Default: `TPDF`)
           Options: Triangular Probability Density Function - `TPDF`
                    Rectangular Probability Density Function - `RPDF`
                    Gaussian Probability Density Function - `GPDF`

    Returns:
        torch.Tensor: waveform dithered with the chosen density function
    """
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    channel_size = waveform.size()[0] - 1
    time_size = waveform.size()[-1] - 1

    random_channel = int(torch.randint(channel_size, [1, ]).item()) if channel_size > 0 else 0
    random_time = int(torch.randint(time_size, [1, ]).item()) if time_size > 0 else 0

    number_of_bits = 16
    up_scaling = 2 ** (number_of_bits - 1) - 2
    signal_scaled = waveform * up_scaling
    down_scaling = 2 ** (number_of_bits - 1)

    signal_scaled_dis = waveform
    if (density_function == "RPDF"):
        RPDF = waveform[random_channel][random_time] - 0.5

        signal_scaled_dis = signal_scaled + RPDF
    elif (density_function == "GPDF"):
        # TODO Replace by distribution code once
        # https://github.com/pytorch/pytorch/issues/29843 is resolved
        # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()

        num_rand_variables = 6

        gaussian = waveform[random_channel][random_time]
        for ws in num_rand_variables * [time_size]:
            rand_chan = int(torch.randint(channel_size, [1, ]).item())
            gaussian += waveform[rand_chan][int(torch.randint(ws, [1, ]).item())]

        signal_scaled_dis = signal_scaled + gaussian
    else:
        TPDF = torch.bartlett_window(time_size + 1)
        TPDF = TPDF.repeat((channel_size + 1), 1)
        signal_scaled_dis = signal_scaled + TPDF

    quantised_signal_scaled = torch.round(signal_scaled_dis)
    quantised_signal = quantised_signal_scaled / down_scaling
    return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])


def dither(waveform, density_function="TPDF", noise_shaping=False):
    # type: (Tensor, str, bool) -> Tensor
    r"""Dither increases the perceived dynamic range of audio stored at a
    particular bit-depth by eliminating nonlinear truncation distortion
    (i.e. adding minimally perceived noise to mask distortion caused by quantization).

    Args:
       waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
       density_function (string): The density function of a
           continuous random variable (Default: `TPDF`)
           Options: Triangular Probability Density Function - `TPDF`
                    Rectangular Probability Density Function - `RPDF`
                    Gaussian Probability Density Function - `GPDF`
       noise_shaping (boolean): a filtering process that shapes the spectral
           energy of quantisation error (Default: `False`)

    Returns:
       torch.Tensor: waveform dithered
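
    Example
        >>> # Illustrative usage (values are assumptions):
        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> dithered = dither(waveform, density_function="TPDF", noise_shaping=True)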
    """
    dithered = _apply_probability_distribution(waveform, density_function=density_function)

    if noise_shaping:
        return _add_noise_shaping(dithered, waveform)
    else:
        return dithered


def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
    # type: (Tensor, int, float, int) -> Tensor
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = math.ceil(sample_rate / freq_low)

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = math.ceil(waveform_length / frame_size)

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[
            ..., :num_of_frames, :
        ]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[
            ..., :num_of_frames, :
        ]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(a, b, thresh=0.99):
    # type: (Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], float) -> Tuple[Tensor, Tensor]
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = (a[0] > thresh * b[0])
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(nccf, sample_rate, freq_high):
    # type: (Tensor, int, int) -> Tensor
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the max among all the lags is very close
    to the first half of lags, then the latter is taken.
    """

    lag_min = math.ceil(sample_rate / freq_high)

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(indices, win_length):
    # type: (Tensor, int) -> Tensor
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(
        indices, (pad_length, 0), mode="constant", value=0.
    )

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
    waveform,
    sample_rate,
    frame_time=10 ** (-2),
    win_length=30,
    freq_low=85,
    freq_high=3400,
):
    # type: (Tensor, int, float, int, int, int) -> Tensor
    r"""Detect pitch frequency.

    It is implemented using normalized cross-correlation function and median smoothing.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float): Duration of a frame, in seconds, used to compute the NCCF
        win_length (int): The window length for median smoothing (in number of frames)
        freq_low (int): Lowest frequency that can be detected (Hz)
        freq_high (int): Highest frequency that can be detected (Hz)

    Returns:
        freq (torch.Tensor): Tensor of audio of dimension (..., frame)
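
    Example
        >>> # Illustrative usage on a synthetic 440 Hz tone (values are assumptions):
        >>> sample_rate = 16000
        >>> t = torch.arange(sample_rate, dtype=torch.float) / sample_rate
        >>> waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)
        >>> freq = detect_pitch_frequency(waveform, sample_rate)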
    """

    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq