# -*- coding: utf-8 -*-

import math
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor

__all__ = [
    "spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "DB_to_amplitude",
    "compute_deltas",
    "create_fb_matrix",
    "create_dct",
    "detect_pitch_frequency",
    "mu_law_encoding",
    "mu_law_decoding",
    "complex_norm",
    "angle",
    "magphase",
    "phase_vocoder",
    "mask_along_axis",
    "mask_along_axis_iid",
    "sliding_window_cmn",
    "spectral_centroid",
]


def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``

    Returns:
        Tensor: Dimension (..., freq, time), where freq is
        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
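
    Example
        >>> # A minimal sketch; the 16 kHz mono waveform and hann window
        >>> # below are illustrative assumptions, not requirements.
        >>> waveform = torch.randn(1, 16000)
        >>> window = torch.hann_window(400)
        >>> specgram = spectrogram(waveform, pad=0, window=window, n_fft=400,
        ...                        hop_length=200, win_length=400, power=2.,
        ...                        normalized=False)
        >>> specgram.shape  # (channel, n_fft // 2 + 1, n_frame)
        torch.Size([1, 201, 81])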
    """

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=False,
        onesided=onesided,
        return_complex=True,
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

    if normalized:
        spec_f /= window.pow(2.).sum().sqrt()
    if power is not None:
        if power == 1.0:
            return spec_f.abs()
        return spec_f.abs().pow(power)
    return torch.view_as_real(spec_f)


def griffinlim(
        specgram: Tensor,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: float,
        normalized: bool,
        n_iter: int,
        momentum: float,
        length: Optional[int],
        rand_init: bool
) -> Tensor:
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    *  [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.
    *  [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.
    *  [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows. (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        normalized (bool): Whether to normalize by magnitude after stft.
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (int or None): Array length of the expected output.
        rand_init (bool): Initializes phase randomly if True, to zero otherwise.

    Returns:
        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
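
    Example
        >>> # A minimal sketch; the spectrogram values here are random and
        >>> # purely illustrative, so the recovered waveform is noise-like.
        >>> window = torch.hann_window(400)
        >>> specgram = torch.rand(1, 201, 81)
        >>> waveform = griffinlim(specgram, window, n_fft=400, hop_length=200,
        ...                       win_length=400, power=2., normalized=False,
        ...                       n_iter=32, momentum=0.99, length=None,
        ...                       rand_init=True)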
    """
    assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
    assert momentum >= 0, 'momentum={} < 0'.format(momentum)

    if normalized:
        warnings.warn(
            "The argument normalized is not used in Griffin-Lim, "
            "and will be removed in v0.9.0 release. To suppress this warning, "
            "please use `normalized=False`.")

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    specgram = specgram.pow(1 / power)

    # randomly initialize the phase
    batch, freq, frames = specgram.size()
    if rand_init:
        angles = 2 * math.pi * torch.rand(batch, freq, frames)
    else:
        angles = torch.zeros(batch, freq, frames)
    angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
        .to(dtype=specgram.dtype, device=specgram.device)
    specgram = specgram.unsqueeze(-1).expand_as(angles)

    # And initialize the previous iterate to 0
    rebuilt = torch.tensor(0.)

    for _ in range(n_iter):
        # Store the previous iterate
        tprev = rebuilt

        # Invert with our current estimate of the phases
        inverse = torch.istft(specgram * angles,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              window=window,
                              length=length).float()

        # Rebuild the spectrogram
        rebuilt = torch.view_as_real(
            torch.stft(
                input=inverse,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                center=True,
                pad_mode='reflect',
                normalized=False,
                onesided=True,
                return_complex=True,
            )
        )

        # Update our phase estimates
        angles = rebuilt
        if momentum:
            angles = angles - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles))

    # Return the final phase estimates
    waveform = torch.istft(specgram * angles,
                           n_fft=n_fft,
                           hop_length=hop_length,
                           win_length=win_length,
                           window=window,
                           length=length)

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform


def amplitude_to_DB(
        x: Tensor,
        multiplier: float,
        amin: float,
        db_multiplier: float,
        top_db: Optional[float] = None
) -> Tensor:
    r"""Turn a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        x (Tensor): Input tensor before being converted to decibel scale
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        Tensor: Output tensor in decibel scale
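
    Example
        >>> # A minimal sketch with illustrative values: convert a power
        >>> # spectrogram to dB, clamped to an 80 dB dynamic range.
        >>> specgram = torch.rand(1, 201, 81)
        >>> specgram_db = amplitude_to_DB(specgram, multiplier=10., amin=1e-10,
        ...                               db_multiplier=0., top_db=80.)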
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        x_db = x_db.clamp(min=x_db.max().item() - top_db)

    return x_db


def DB_to_amplitude(
        x: Tensor,
        ref: float,
        power: float
) -> Tensor:
    r"""Turn a tensor from the decibel scale to the power/amplitude scale.

    Args:
        x (Tensor): Input tensor before being converted to power/amplitude scale.
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Returns:
        Tensor: Output tensor in power/amplitude scale.
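
    Example
        >>> # An illustrative round trip: power -> dB -> power. The
        >>> # multiplier/ref values assume a power-scaled input.
        >>> specgram = torch.rand(1, 201, 81)
        >>> specgram_db = amplitude_to_DB(specgram, multiplier=10., amin=1e-10,
        ...                               db_multiplier=0.)
        >>> torch.allclose(DB_to_amplitude(specgram_db, ref=1., power=1.),
        ...                specgram, atol=1e-5)
        True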
    """
    return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)


def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the
            width of the mel band (area normalization). (Default: ``None``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * create_fb_matrix(A.size(-1), ...)``.
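
    Example
        >>> # A minimal sketch; the sizes below (201 freq bins, 40 mel bins,
        >>> # 16 kHz sample rate) are illustrative assumptions.
        >>> fb = create_fb_matrix(n_freqs=201, f_min=0., f_max=8000.,
        ...                       n_mels=40, sample_rate=16000)
        >>> A = torch.rand(1, 201, 81)
        >>> mel_specgram = torch.matmul(A.transpose(-1, -2), fb).transpose(-1, -2)
        >>> mel_specgram.shape
        torch.Size([1, 40, 81])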
    """

    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
    m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
    m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
    f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb


def create_dct(
        n_mfcc: int,
        n_mels: int,
        norm: Optional[str]
) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
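
    Example
        >>> # A minimal sketch: project 40 illustrative mel bins down to 13 MFCCs.
        >>> dct_mat = create_dct(n_mfcc=13, n_mels=40, norm='ortho')
        >>> mel_specgram = torch.rand(1, 40, 81)
        >>> mfcc = torch.matmul(mel_specgram.transpose(-1, -2), dct_mat).transpose(-1, -2)
        >>> mfcc.shape
        torch.Size([1, 13, 81])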
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def mu_law_encoding(
        x: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law encoding
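
    Example
        >>> # A minimal sketch; any waveform already scaled to [-1, 1] works.
        >>> waveform = torch.linspace(-1., 1., steps=5)
        >>> x_mu = mu_law_encoding(waveform, quantization_channels=256)
        >>> int(x_mu.min()), int(x_mu.max())
        (0, 255)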
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(
        x_mu: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
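
    Example
        >>> # An illustrative round trip; mu-law quantization error with 256
        >>> # channels stays below roughly 0.03 for signals in [-1, 1].
        >>> waveform = torch.rand(100) * 2 - 1
        >>> decoded = mu_law_decoding(mu_law_encoding(waveform, 256), 256)
        >>> torch.allclose(decoded, waveform, atol=0.03)
        True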
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = ((x_mu) / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x


def complex_norm(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tensor:
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        Tensor: Power of the normed input tensor. Shape of `(..., )`
    """

    # Replace by torch.norm once issue is fixed
    # https://github.com/pytorch/pytorch/issues/34279
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)


def angle(
        complex_tensor: Tensor
) -> Tensor:
    r"""Compute the angle of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`

    Return:
        Tensor: Angle of a complex tensor. Shape of `(..., )`
    """
    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


def magphase(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tuple[Tensor, Tensor]:
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        (Tensor, Tensor): The magnitude and phase of the complex tensor
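
    Example
        >>> # A minimal sketch; the shape (channel, freq, time, 2) is an
        >>> # illustrative real-view complex spectrogram.
        >>> complex_tensor = torch.randn(1, 201, 81, 2)
        >>> mag, phase = magphase(complex_tensor)
        >>> mag.shape, phase.shape
        (torch.Size([1, 201, 81]), torch.Size([1, 201, 81]))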
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase


def phase_vocoder(
        complex_specgrams: Tensor,
        rate: float,
        phase_advance: Tensor
) -> Tensor:
    r"""Given an STFT tensor, speed up in time without modifying pitch by a
    factor of ``rate``.

    Args:
        complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

    Returns:
        Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`

    Example
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        ...    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """
    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))

    time_steps = torch.arange(0,
                              complex_specgrams.size(-2),
                              rate,
                              device=complex_specgrams.device,
                              dtype=complex_specgrams.dtype)

    alphas = time_steps % 1.0
    phase_0 = angle(complex_specgrams[..., :1, :])

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long())

    angle_0 = angle(complex_specgrams_0)
    angle_1 = angle(complex_specgrams_1)

    norm_0 = torch.norm(complex_specgrams_0, p=2, dim=-1)
    norm_1 = torch.norm(complex_specgrams_1, p=2, dim=-1)

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    real_stretch = mag * torch.cos(phase_acc)
    imag_stretch = mag * torch.sin(phase_acc)

    complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])

    return complex_specgrams_stretch


def mask_along_axis_iid(
        specgrams: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

    Args:
        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

    Returns:
        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
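
    Example
        >>> # A minimal sketch; batch/channel/freq/time sizes are illustrative.
        >>> # Each example in the batch gets its own random time-mask interval.
        >>> specgrams = torch.randn(4, 2, 201, 81)
        >>> masked = mask_along_axis_iid(specgrams, mask_param=30,
        ...                              mask_value=0., axis=3)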
    """

    if axis != 2 and axis != 3:
        raise ValueError('Only Frequency and Time masking are supported')

    device = specgrams.device
    dtype = specgrams.dtype

    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value[..., None, None]
    mask_end = (min_value + value)[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(
        specgram: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram (channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

    Returns:
        Tensor: Masked spectrogram of dimensions (channel, freq, time)
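
    Example
        >>> # A minimal sketch of SpecAugment-style frequency masking;
        >>> # sizes and mask_param are illustrative.
        >>> specgram = torch.randn(2, 201, 81)
        >>> masked = mask_along_axis(specgram, mask_param=27, mask_value=0., axis=1)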
    """

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()

    assert mask_end - mask_start < mask_param
    if axis == 1:
        specgram[:, mask_start:mask_end] = mask_value
    elif axis == 2:
        specgram[:, :, mask_start:mask_end] = mask_value
    else:
        raise ValueError('Only Frequency and Time masking are supported')

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram


def compute_deltas(
        specgram: Tensor,
        win_length: int = 5,
        mode: str = "replicate"
) -> Tensor:
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. math::
       d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is ``(win_length-1)//2``.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

    Returns:
        Tensor: Tensor of deltas of dimension (..., freq, time)

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """
    device = specgram.device
    dtype = specgram.dtype

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def _compute_nccf(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float,
        freq_low: int
) -> Tensor:
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = int(math.ceil(sample_rate / freq_low))

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = int(math.ceil(waveform_length / frame_size))

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(
        a: Tuple[Tensor, Tensor],
        b: Tuple[Tensor, Tensor],
        thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = (a[0] > thresh * b[0])
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(
        nccf: Tensor,
        sample_rate: int,
        freq_high: int
) -> Tensor:
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the max among all the lags is very close
    to the first half of lags, then the latter is taken.
    """

    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(
        indices: Tensor,
        win_length: int
) -> Tensor:
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(
        indices, (pad_length, 0), mode="constant", value=0.
    )

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float = 10 ** (-2),
        win_length: int = 30,
        freq_low: int = 85,
        freq_high: int = 3400,
) -> Tensor:
    r"""Detect pitch frequency.

    It is implemented using the normalized cross-correlation function and median smoothing.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
        freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

    Returns:
        Tensor: Tensor of freq of dimension (..., frame)
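
    Example
        >>> # A minimal sketch; a pure 440 Hz tone at an illustrative 16 kHz
        >>> # sample rate should yield frame estimates near 440.
        >>> sample_rate = 16000
        >>> t = torch.arange(sample_rate) / sample_rate
        >>> waveform = torch.sin(2 * math.pi * 440 * t)
        >>> freq = detect_pitch_frequency(waveform, sample_rate)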
    """
    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq


def sliding_window_cmn(
    waveform: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        waveform (Tensor): Tensor of features of dimension (..., time, freq)
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Tensor of same dimension as input, (..., time, freq)
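
    Example
        >>> # A minimal sketch; the (time, freq) feature matrix is illustrative.
        >>> feats = torch.randn(1000, 40)
        >>> normed = sliding_window_cmn(feats, cmn_window=600, center=True)
        >>> normed.shape
        torch.Size([1000, 40])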
    """
    input_shape = waveform.shape
    num_frames, num_feats = input_shape[-2:]
    waveform = waveform.view(-1, num_frames, num_feats)
    num_channels = waveform.shape[0]

    dtype = waveform.dtype
    device = waveform.device
    last_window_start = last_window_end = -1
    cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cmn_waveform = torch.zeros(
        num_channels, num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            input_part = waveform[:, window_start: window_end - window_start, :]
            cur_sum += torch.sum(input_part, 1)
            if norm_vars:
                cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
        else:
            if window_start > last_window_start:
                frame_to_remove = waveform[:, last_window_start, :]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = waveform[:, last_window_end, :]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        cmn_waveform[:, t, :] = waveform[:, t, :] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                cmn_waveform[:, t, :] = torch.zeros(
                    num_channels, num_feats, dtype=dtype, device=device)
            else:
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmn_waveform[:, t, :] *= variance

    cmn_waveform = cmn_waveform.view(input_shape[:-2] + (num_frames, num_feats))
    if len(input_shape) == 2:
        cmn_waveform = cmn_waveform.squeeze(0)
    return cmn_waveform


def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time)
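
    Example
        >>> # A minimal sketch; the waveform and STFT parameters below are
        >>> # illustrative assumptions.
        >>> waveform = torch.randn(1, 16000)
        >>> window = torch.hann_window(400)
        >>> centroid = spectral_centroid(waveform, sample_rate=16000, pad=0,
        ...                              window=window, n_fft=400,
        ...                              hop_length=200, win_length=400)
        >>> centroid.shape  # (channel, n_frame)
        torch.Size([1, 81])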
    """
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
                           device=specgram.device).reshape((-1, 1))
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)