# -*- coding: utf-8 -*-

import io
import math
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
import torchaudio

__all__ = [
    "spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "DB_to_amplitude",
    "compute_deltas",
    "compute_kaldi_pitch",
    "create_fb_matrix",
    "create_dct",
    "detect_pitch_frequency",
    "mu_law_encoding",
    "mu_law_decoding",
    "complex_norm",
    "angle",
    "magphase",
    "phase_vocoder",
    "mask_along_axis",
    "mask_along_axis_iid",
    "sliding_window_cmn",
    "spectral_centroid",
    "apply_codec",
    "resample",
]


def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
        return_complex: bool = False,
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        return_complex (bool, optional):
            When ``return_complex = True``, this function returns the resulting Tensor in
            complex dtype, otherwise it returns the resulting Tensor in real dtype with an extra
            dimension for the real and imaginary parts (see ``torch.view_as_real``).
            When ``power`` is provided, the value must be ``False``, as the resulting
            Tensor represents real-valued power.

    Returns:
        Tensor: Dimension (..., freq, time), where freq is
        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
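
    Example
        >>> # A minimal sketch; values are illustrative (1 s of mono audio at 16 kHz).
        >>> waveform = torch.randn(1, 16000)
        >>> window = torch.hann_window(400)
        >>> spec = spectrogram(waveform, pad=0, window=window, n_fft=400,
        ...                    hop_length=200, win_length=400, power=2., normalized=False)
        >>> spec.shape
        torch.Size([1, 201, 81])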
    """
    if power is not None and return_complex:
        raise ValueError(
            'When `power` is provided, the return value is real-valued. '
            'Therefore, `return_complex` must be False.')

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=False,
        onesided=onesided,
        return_complex=True,
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

    if normalized:
        spec_f /= window.pow(2.).sum().sqrt()
    if power is not None:
        if power == 1.0:
            return spec_f.abs()
        return spec_f.abs().pow(power)
    if not return_complex:
        return torch.view_as_real(spec_f)
    return spec_f


def griffinlim(
        specgram: Tensor,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: float,
        n_iter: int,
        momentum: float,
        length: Optional[int],
        rand_init: bool
) -> Tensor:
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    Implementation ported from `librosa`.

    *  [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.
    *  [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.
    *  [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows. (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (int or None): Array length of the expected output.
        rand_init (bool): Initializes phase randomly if True and to zero otherwise.

    Returns:
        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
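
    Example
        >>> # A minimal sketch; the input magnitudes are random and purely illustrative.
        >>> specgram = torch.rand(1, 201, 81)
        >>> window = torch.hann_window(400)
        >>> waveform = griffinlim(specgram, window, n_fft=400, hop_length=200,
        ...                       win_length=400, power=2., n_iter=32,
        ...                       momentum=0.99, length=None, rand_init=True)
        >>> waveform.shape
        torch.Size([1, 16000])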
    """
    assert momentum < 1, 'momentum={} >= 1 can be unstable'.format(momentum)
    assert momentum >= 0, 'momentum={} < 0'.format(momentum)

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    specgram = specgram.pow(1 / power)

    # randomly initialize the phase
    batch, freq, frames = specgram.size()
    if rand_init:
        angles = 2 * math.pi * torch.rand(batch, freq, frames)
    else:
        angles = torch.zeros(batch, freq, frames)
    angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
        .to(dtype=specgram.dtype, device=specgram.device)
    specgram = specgram.unsqueeze(-1).expand_as(angles)

    # And initialize the previous iterate to 0
    rebuilt = torch.tensor(0.)

    for _ in range(n_iter):
        # Store the previous iterate
        tprev = rebuilt

        # Invert with our current estimate of the phases
        inverse = torch.istft(specgram * angles,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              window=window,
                              length=length)

        # Rebuild the spectrogram
        rebuilt = torch.view_as_real(
            torch.stft(
                input=inverse,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                center=True,
                pad_mode='reflect',
                normalized=False,
                onesided=True,
                return_complex=True,
            )
        )

        # Update our phase estimates
        angles = rebuilt
        if momentum:
            angles = angles - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles))

    # Return the final phase estimates
    waveform = torch.istft(specgram * angles,
                           n_fft=n_fft,
                           hop_length=hop_length,
                           win_length=win_length,
                           window=window,
                           length=length)

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform


def amplitude_to_DB(
        x: Tensor,
        multiplier: float,
        amin: float,
        db_multiplier: float,
        top_db: Optional[float] = None
) -> Tensor:
    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

    The output of each tensor in a batch depends on the maximum value of that tensor,
    and so may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
          the form `(..., freq, time)`. Batched inputs should include a channel dimension and
          have the form `(batch, channel, freq, time)`.
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x`` to as a lower bound
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        Tensor: Output tensor in decibel scale
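
    Example
        >>> # A minimal sketch; values are illustrative (a random power spectrogram).
        >>> specgram = torch.rand(1, 201, 100) ** 2
        >>> db = amplitude_to_DB(specgram, multiplier=10., amin=1e-10,
        ...                      db_multiplier=0., top_db=80.)
        >>> db.shape
        torch.Size([1, 201, 100])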
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        # Expand batch
        shape = x_db.size()
        packed_channels = shape[-3] if x_db.dim() > 2 else 1
        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])

        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))

        # Repack batch
        x_db = x_db.reshape(shape)

    return x_db


def DB_to_amplitude(
        x: Tensor,
        ref: float,
        power: float
) -> Tensor:
    r"""Turn a tensor from the decibel scale to the power/amplitude scale.

    Args:
        x (Tensor): Input tensor before being converted to power/amplitude scale.
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Returns:
        Tensor: Output tensor in power/amplitude scale.
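
    Example
        >>> # A round-trip sketch with ``amplitude_to_DB``; values are illustrative.
        >>> specgram = torch.rand(1, 201, 100) ** 2
        >>> db = amplitude_to_DB(specgram, multiplier=10., amin=1e-10,
        ...                      db_multiplier=0., top_db=None)
        >>> restored = DB_to_amplitude(db, ref=1., power=1.)
        >>> torch.allclose(restored, specgram, atol=1e-5)
        True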
    """
    return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequency in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted to Hz
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0**(mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = (mels >= min_log_mel)
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs


def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the
            width of the mel band (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * create_fb_matrix(A.size(-1), ...)``.
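
    Example
        >>> # A minimal sketch; values are illustrative (201 STFT bins to 40 mel bands).
        >>> fb = create_fb_matrix(n_freqs=201, f_min=0., f_max=8000., n_mels=40, sample_rate=16000)
        >>> specgram = torch.rand(1, 100, 201)  # (channel, time, n_freqs)
        >>> mel_specgram = torch.matmul(specgram, fb)
        >>> mel_specgram.shape
        torch.Size([1, 100, 40])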
    """
    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
    m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb


def create_dct(
        n_mfcc: int,
        n_mels: int,
        norm: Optional[str]
) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
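
    Example
        >>> # A minimal sketch; values are illustrative (keep 13 coefficients of 40 mel bands).
        >>> dct_mat = create_dct(n_mfcc=13, n_mels=40, norm='ortho')
        >>> mel_specgram = torch.rand(1, 100, 40)  # (channel, time, n_mels)
        >>> mfcc = torch.matmul(mel_specgram, dct_mat)
        >>> mfcc.shape
        torch.Size([1, 100, 13])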
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def mu_law_encoding(
        x: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law encoding
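
    Example
        >>> # A round-trip sketch; values are illustrative and the input is scaled to [-1, 1].
        >>> waveform = torch.rand(1, 16) * 2 - 1
        >>> encoded = mu_law_encoding(waveform, quantization_channels=256)
        >>> decoded = mu_law_decoding(encoded, quantization_channels=256)
        >>> bool((decoded - waveform).abs().max() < 0.03)
        True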
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(
        x_mu: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = (x_mu / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x


def complex_norm(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tensor:
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        Tensor: Power of the normed input tensor. Shape of `(..., )`
    """
    # Replace by torch.norm once issue is fixed
    # https://github.com/pytorch/pytorch/issues/34279
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)


def angle(
        complex_tensor: Tensor
) -> Tensor:
    r"""Compute the angle of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`

    Returns:
        Tensor: Angle of a complex tensor. Shape of `(..., )`
    """
    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


def magphase(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tuple[Tensor, Tensor]:
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        (Tensor, Tensor): The magnitude and phase of the complex tensor
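
    Example
        >>> # A minimal sketch; a real tensor with a trailing (re, im) dimension.
        >>> complex_specgram = torch.randn(1, 201, 100, 2)
        >>> mag, phase = magphase(complex_specgram)
        >>> mag.shape, phase.shape
        (torch.Size([1, 201, 100]), torch.Size([1, 201, 100]))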
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase


def phase_vocoder(
        complex_specgrams: Tensor,
        rate: float,
        phase_advance: Tensor
) -> Tensor:
    r"""Given an STFT tensor, speed up in time without modifying pitch by a
    factor of ``rate``.

    Args:
        complex_specgrams (Tensor):
            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

    Returns:
        Tensor:
            Stretched spectrogram. The resulting tensor is of the same dtype as the input
            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

    Example - With Tensor of complex dtype
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time)
        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        ...    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231])

    Example - With Tensor of real dtype and extra dimension for complex field
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        ...    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """
    if rate == 1.0:
        return complex_specgrams

    if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
        raise ValueError(
            "complex_specgrams must be either native complex tensors or "
            "real valued tensors with shape (..., 2)")

    is_complex = complex_specgrams.is_complex()

    if not is_complex:
        complex_specgrams = torch.view_as_complex(complex_specgrams)

    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))

    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
    # Note torch.real is a view so it does not incur any memory copy.
    real_dtype = torch.real(complex_specgrams).dtype
    time_steps = torch.arange(
        0,
        complex_specgrams.size(-1),
        rate,
        device=complex_specgrams.device,
        dtype=real_dtype)

    alphas = time_steps % 1.0
    phase_0 = complex_specgrams[..., :1].angle()

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

    angle_0 = complex_specgrams_0.angle()
    angle_1 = complex_specgrams_1.angle()

    norm_0 = complex_specgrams_0.abs()
    norm_1 = complex_specgrams_1.abs()

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    complex_specgrams_stretch = torch.polar(mag, phase_acc)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

    if not is_complex:
        return torch.view_as_real(complex_specgrams_stretch)
    return complex_specgrams_stretch


def mask_along_axis_iid(
        specgrams: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

    Args:
        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

    Returns:
        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
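
    Example
        >>> # A minimal sketch; mask up to 30 time steps per (batch, channel) pair.
        >>> specgrams = torch.randn(4, 2, 201, 100)
        >>> masked = mask_along_axis_iid(specgrams, mask_param=30, mask_value=0., axis=3)
        >>> masked.shape
        torch.Size([4, 2, 201, 100])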
    """

    if axis != 2 and axis != 3:
        raise ValueError('Only Frequency and Time masking are supported')

    device = specgrams.device
    dtype = specgrams.dtype

    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value[..., None, None]
    mask_end = (min_value + value)[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(
        specgram: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram (channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

    Returns:
        Tensor: Masked spectrogram of dimensions (channel, freq, time)
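
    Example
        >>> # A minimal sketch; mask up to 10 frequency bins across all channels.
        >>> specgram = torch.randn(2, 201, 100)
        >>> masked = mask_along_axis(specgram, mask_param=10, mask_value=0., axis=1)
        >>> masked.shape
        torch.Size([2, 201, 100])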
    """

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()

    assert mask_end - mask_start < mask_param
    if axis == 1:
        specgram[:, mask_start:mask_end] = mask_value
    elif axis == 2:
        specgram[:, :, mask_start:mask_end] = mask_value
    else:
        raise ValueError('Only Frequency and Time masking are supported')

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram


def compute_deltas(
        specgram: Tensor,
        win_length: int = 5,
        mode: str = "replicate"
) -> Tensor:
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. math::
       d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is ``(win_length-1)//2``.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

    Returns:
        Tensor: Tensor of deltas of dimension (..., freq, time)

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """
    device = specgram.device
    dtype = specgram.dtype

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def _compute_nccf(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float,
        freq_low: int
) -> Tensor:
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = int(math.ceil(sample_rate / freq_low))

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = int(math.ceil(waveform_length / frame_size))

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(
        a: Tuple[Tensor, Tensor],
        b: Tuple[Tensor, Tensor],
        thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = (a[0] > thresh * b[0])
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(
        nccf: Tensor,
        sample_rate: int,
        freq_high: int
) -> Tensor:
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the max among all the lags is very close
    to the first half of lags, then the latter is taken.
    """

    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(
        indices: Tensor,
        win_length: int
) -> Tensor:
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(
        indices, (pad_length, 0), mode="constant", value=0.
    )

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float = 10 ** (-2),
        win_length: int = 30,
        freq_low: int = 85,
        freq_high: int = 3400,
) -> Tensor:
    r"""Detect pitch frequency.

    It is implemented using normalized cross-correlation function and median smoothing.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
        freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

    Returns:
        Tensor: Tensor of freq of dimension (..., frame)
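
    Example
        >>> # A minimal sketch; values are illustrative (a synthetic 440 Hz tone).
        >>> sample_rate = 44100
        >>> t = torch.linspace(0, 1, sample_rate)
        >>> waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)
        >>> pitch = detect_pitch_frequency(waveform, sample_rate)
        >>> bool(430 < pitch.median() < 450)
        True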
    """
    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq


def sliding_window_cmn(
    specgram: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if ``center == False``; ignored if ``center == True`` (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Tensor matching input shape (..., time, freq)
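
    Example
        >>> # A minimal sketch; values are illustrative (300 frames of 40-dim features).
        >>> specgram = torch.randn(1, 300, 40)
        >>> cmn = sliding_window_cmn(specgram, cmn_window=600, min_cmn_window=100)
        >>> cmn.shape
        torch.Size([1, 300, 40])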
    """
    input_shape = specgram.shape
    num_frames, num_feats = input_shape[-2:]
    specgram = specgram.view(-1, num_frames, num_feats)
    num_channels = specgram.shape[0]

    dtype = specgram.dtype
    device = specgram.device
    last_window_start = last_window_end = -1
    cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cmn_specgram = torch.zeros(
        num_channels, num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            # First iteration: window_start has been clamped to 0 above,
            # so this slice is equivalent to specgram[:, 0:window_end, :].
            input_part = specgram[:, window_start: window_end - window_start, :]
            cur_sum += torch.sum(input_part, 1)
            if norm_vars:
                cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
        else:
            if window_start > last_window_start:
                frame_to_remove = specgram[:, last_window_start, :]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = specgram[:, last_window_end, :]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                cmn_specgram[:, t, :] = torch.zeros(
                    num_channels, num_feats, dtype=dtype, device=device)
            else:
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmn_specgram[:, t, :] *= variance

    cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
    if len(input_shape) == 2:
        cmn_specgram = cmn_specgram.squeeze(0)
    return cmn_specgram


def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time)
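
    Example
        >>> # A minimal sketch; values are illustrative (1 s of noise at 16 kHz).
        >>> waveform = torch.randn(1, 16000)
        >>> window = torch.hann_window(400)
        >>> centroid = spectral_centroid(waveform, 16000, pad=0, window=window,
        ...                              n_fft=400, hop_length=200, win_length=400)
        >>> centroid.shape
        torch.Size([1, 81])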
    """
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
                           device=specgram.device).reshape((-1, 1))
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)


@_mod_utils.requires_sox()
def apply_codec(
    waveform: Tensor,
    sample_rate: int,
    format: str,
    channels_first: bool = True,
    compression: Optional[float] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool):
            When True, both the input and output Tensor have dimension ``[channel, time]``.
            Otherwise, they have dimension ``[time, channel]``.
        compression (float): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        torch.Tensor: Resulting Tensor.
        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
    """
    bytes = io.BytesIO()
    torchaudio.backend.sox_io_backend.save(bytes,
                                           waveform,
                                           sample_rate,
                                           channels_first,
                                           compression,
                                           format,
                                           encoding,
                                           bits_per_sample
                                           )
    bytes.seek(0)
    augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
        bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
    return augmented


@_mod_utils.requires_kaldi()
def compute_kaldi_pitch(
        waveform: torch.Tensor,
        sample_rate: float,
        frame_length: float = 25.0,
        frame_shift: float = 10.0,
        min_f0: float = 50,
        max_f0: float = 400,
        soft_min_f0: float = 10.0,
        penalty_factor: float = 0.1,
        lowpass_cutoff: float = 1000,
        resample_frequency: float = 4000,
        delta_pitch: float = 0.005,
        nccf_ballast: float = 7000,
        lowpass_filter_width: int = 1,
        upsample_filter_width: int = 5,
        max_frames_latency: int = 0,
        frames_per_chunk: int = 0,
        simulate_first_pass_online: bool = False,
        recompute_frame: int = 500,
        snip_edges: bool = True,
) -> torch.Tensor:
    """Extract pitch based on method described in [1].

    This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.

    Args:
        waveform (Tensor):
            The input waveform of shape `(..., time)`.
        sample_rate (float):
            Sample rate of `waveform`.
        frame_length (float, optional):
            Frame length in milliseconds. (default: 25.0)
        frame_shift (float, optional):
            Frame shift in milliseconds. (default: 10.0)
        min_f0 (float, optional):
            Minimum F0 to search for (Hz)  (default: 50.0)
        max_f0 (float, optional):
            Maximum F0 to search for (Hz)  (default: 400.0)
        soft_min_f0 (float, optional):
            Minimum f0, applied in soft way, must not exceed min-f0  (default: 10.0)
        penalty_factor (float, optional):
            Cost factor for F0 change.  (default: 0.1)
        lowpass_cutoff (float, optional):
            Cutoff frequency for LowPass filter (Hz) (default: 1000)
        resample_frequency (float, optional):
            Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
            (default: 4000)
        delta_pitch (float, optional):
            Smallest relative change in pitch that our algorithm measures. (default: 0.005)
        nccf_ballast (float, optional):
            Increasing this factor reduces NCCF for quiet frames (default: 7000)
        lowpass_filter_width (int, optional):
            Integer that determines filter width of lowpass filter, more gives sharper filter.
            (default: 1)
        upsample_filter_width (int, optional):
            Integer that determines filter width when upsampling NCCF. (default: 5)
        max_frames_latency (int, optional):
            Maximum number of frames of latency that we allow pitch tracking to introduce into
            the feature processing (affects output only if ``frames_per_chunk > 0`` and
            ``simulate_first_pass_online=True``) (default: 0)
        frames_per_chunk (int, optional):
            The number of frames used for energy normalization. (default: 0)
        simulate_first_pass_online (bool, optional):
            If true, the function will output features that correspond to what an online decoder
            would see in the first pass of decoding -- not the final version of the features,
            which is the default. (default: False)
            Relevant if ``frames_per_chunk > 0``.
        recompute_frame (int, optional):
            Only relevant for compatibility with online pitch extraction.
            A non-critical parameter; the frame at which we recompute some of the forward pointers,
            after revising our estimate of the signal energy.
            Relevant if ``frames_per_chunk > 0``. (default: 500)
        snip_edges (bool, optional):
            If this is set to false, the incomplete frames near the ending edge won't be snipped,
            so that the number of frames is the file size divided by the frame-shift.
            This makes different types of features give the same number of frames. (default: True)

    Returns:
       Tensor: Pitch feature. Shape: ``(batch, frames, 2)`` where the last dimension
       corresponds to pitch and NCCF.

    Reference:
        - A pitch extraction algorithm tuned for automatic speech recognition

          P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur

          2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),

          Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
    """
    shape = waveform.shape
    waveform = waveform.reshape(-1, shape[-1])
    result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
        waveform, sample_rate, frame_length, frame_shift,
        min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
        resample_frequency, delta_pitch, nccf_ballast,
        lowpass_filter_width, upsample_filter_width, max_frames_latency,
        frames_per_chunk, simulate_first_pass_online, recompute_frame,
        snip_edges,
    )
    result = result.reshape(shape[:-1] + result.shape[-2:])
    return result


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
                              device: torch.device, dtype: torch.dtype):
    assert lowpass_filter_width > 0
    kernels = []
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= 0.99

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #    y[j] = x(j / new_freq)
    # or,
    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zero crossings.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This will explain the F.conv1d after, with a stride of orig_freq.
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

    for i in range(new_freq):
        t = (-i / new_freq + idx / orig_freq) * base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
        t *= math.pi
        # we do not use torch.hann_window here as we need to evaluate the window
        # at specific positions, not over a regular grid.
        window = torch.cos(t / lowpass_filter_width / 2)**2
        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
        kernel.mul_(window)
        kernels.append(kernel)

    scale = base_freq / orig_freq
    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def resample(
        waveform: Tensor,
        orig_freq: float,
        new_freq: float,
        lowpass_filter_width: int = 6
) -> Tensor:
    r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
    which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
    a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e.
    the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
    https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56

    Args:
        waveform (Tensor): The input signal of dimension (..., time)
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)

    Returns:
        Tensor: The waveform at the new frequency of dimension (..., time).
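
    Example
        >>> # A minimal sketch; downsample 1 s of audio from 44.1 kHz to 16 kHz.
        >>> waveform = torch.randn(1, 44100)
        >>> resampled = resample(waveform, orig_freq=44100., new_freq=16000.)
        >>> resampled.shape
        torch.Size([1, 16000])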
    """
    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    assert orig_freq > 0.0 and new_freq > 0.0

    orig_freq = int(orig_freq)
    new_freq = int(new_freq)
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq = orig_freq // gcd
    new_freq = new_freq // gcd

    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                              waveform.device, waveform.dtype)

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch
    resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
    return resampled