# -*- coding: utf-8 -*-

import io
import math
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
import torchaudio

__all__ = [
    "spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "DB_to_amplitude",
    "compute_deltas",
    "compute_kaldi_pitch",
    "create_fb_matrix",
    "create_dct",
    "detect_pitch_frequency",
    "mu_law_encoding",
    "mu_law_decoding",
    "complex_norm",
    "angle",
    "magphase",
    "phase_vocoder",
    "mask_along_axis",
    "mask_along_axis_iid",
    "sliding_window_cmn",
    "spectral_centroid",
    "apply_codec",
    "resample",
]

def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
        return_complex: bool = False,
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        return_complex (bool, optional):
            When ``return_complex = True``, this function returns the resulting Tensor in
            complex dtype, otherwise it returns the resulting Tensor in real dtype with extra
            dimension for real and imaginary parts. (see ``torch.view_as_real``).
            When ``power`` is provided, the value must be False, as the resulting
            Tensor represents real-valued power.

    Returns:
        Tensor: Dimension (..., freq, time), where freq is
        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
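
    Example
        An illustrative sketch (assumes a 1-second, 16 kHz mono waveform; parameter
        values are arbitrary):

        >>> waveform = torch.randn(1, 16000)
        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> spec = spectrogram(waveform, pad=0, window=window, n_fft=n_fft,
        >>>                    hop_length=200, win_length=n_fft, power=2.,
        >>>                    normalized=False)
        >>> spec.shape
        torch.Size([1, 201, 81])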
    """
    if power is None and not return_complex:
        warnings.warn(
            "The use of pseudo complex type in spectrogram is now deprecated. "
            "Please migrate to native complex type by providing `return_complex=True`. "
            "Please refer to https://github.com/pytorch/audio/issues/1337 "
            "for more details about torchaudio's plan to migrate to native complex type."
        )

    if power is not None and return_complex:
        raise ValueError(
            'When `power` is provided, the return value is real-valued. '
            'Therefore, `return_complex` must be False.')

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=False,
        onesided=onesided,
        return_complex=True,
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

    if normalized:
        spec_f /= window.pow(2.).sum().sqrt()
    if power is not None:
        if power == 1.0:
            return spec_f.abs()
        return spec_f.abs().pow(power)
    if not return_complex:
        return torch.view_as_real(spec_f)
    return spec_f


def _get_complex_dtype(real_dtype: torch.dtype):
    if real_dtype == torch.double:
        return torch.cdouble
    if real_dtype == torch.float:
        return torch.cfloat
    if real_dtype == torch.half:
        return torch.complex32
    raise ValueError(f'Unexpected dtype {real_dtype}')


def griffinlim(
        specgram: Tensor,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: float,
        n_iter: int,
        momentum: float,
        length: Optional[int],
        rand_init: bool
) -> Tensor:
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    Implementation ported from `librosa`.

    *  [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.

    *  [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.

    *  [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows. (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (int or None): Array length of the expected output.
        rand_init (bool): Initializes phase randomly if True, and to zero otherwise.

    Returns:
        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
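
    Example
        An illustrative sketch (assumes a magnitude spectrogram produced with matching
        STFT parameters; the values here are arbitrary):

        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> specgram = torch.rand(1, n_fft // 2 + 1, 100)
        >>> waveform = griffinlim(specgram, window, n_fft, 200, n_fft,
        >>>                       2., 32, 0.99, None, True)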
    """
    assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
    assert momentum >= 0, 'momentum={} < 0'.format(momentum)

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    specgram = specgram.pow(1 / power)

    # initialize the phase
    if rand_init:
        angles = torch.rand(
            specgram.size(),
            dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
    else:
        angles = torch.full(
            specgram.size(), 1,
            dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

    # And initialize the previous iterate to 0
    tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device)
    for _ in range(n_iter):
        # Invert with our current estimate of the phases
        inverse = torch.istft(specgram * angles,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              window=window,
                              length=length)

        # Rebuild the spectrogram
        rebuilt = torch.stft(
            input=inverse,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # Update our phase estimates
        angles = rebuilt
        if momentum:
            angles = angles - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div(angles.abs().add(1e-16))

        # Store the previous iterate
        tprev = rebuilt

    # Return the final phase estimates
    waveform = torch.istft(specgram * angles,
                           n_fft=n_fft,
                           hop_length=hop_length,
                           win_length=win_length,
                           window=window,
                           length=length)

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform


def amplitude_to_DB(
        x: Tensor,
        multiplier: float,
        amin: float,
        db_multiplier: float,
        top_db: Optional[float] = None
) -> Tensor:
    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

    The output of each tensor in a batch depends on the maximum value of that tensor,
    and so may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
          the form `(..., freq, time)`. Batched inputs should include a channel dimension and
          have the form `(batch, channel, freq, time)`.
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        Tensor: Output tensor in decibel scale
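
    Example
        An illustrative sketch (assumes a power spectrogram, hence ``multiplier=10.``):

        >>> power_specgram = torch.rand(1, 201, 100)
        >>> x_db = amplitude_to_DB(power_specgram, multiplier=10., amin=1e-10,
        >>>                        db_multiplier=0., top_db=80.)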
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        # Expand batch
        shape = x_db.size()
        packed_channels = shape[-3] if x_db.dim() > 2 else 1
        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])

        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))

        # Repack batch
        x_db = x_db.reshape(shape)

    return x_db


def DB_to_amplitude(
        x: Tensor,
        ref: float,
        power: float
) -> Tensor:
    r"""Turn a tensor from the decibel scale to the power/amplitude scale.

    Args:
        x (Tensor): Input tensor before being converted to power/amplitude scale.
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Returns:
        Tensor: Output tensor in power/amplitude scale.
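
    Example
        An illustrative sketch (decibel values chosen for easy inspection):

        >>> x_db = torch.tensor([0., 10., 20.])
        >>> DB_to_amplitude(x_db, ref=1., power=1.)  # yields 1., 10., 100.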
    """
    return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequency in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted to Hz
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0**(mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = (mels >= min_log_mel)
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs


def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * create_fb_matrix(A.size(-1), ...)``.
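
    Example
        An illustrative sketch (parameter values are arbitrary; the matrix is applied by
        matrix-multiplying a spectrogram whose last dimension is frequency):

        >>> n_fft = 400
        >>> fb = create_fb_matrix(n_freqs=n_fft // 2 + 1, f_min=0., f_max=8000.,
        >>>                       n_mels=40, sample_rate=16000)
        >>> specgram = torch.rand(1, 100, n_fft // 2 + 1)  # (channel, time, freq)
        >>> mel_specgram = torch.matmul(specgram, fb)      # (channel, time, n_mels)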
    """

    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
    m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)

    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb


def create_dct(
        n_mfcc: int,
        n_mels: int,
        norm: Optional[str]
) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
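
    Example
        An illustrative sketch (a mel spectrogram is right-multiplied by the matrix along
        its mel dimension):

        >>> dct_mat = create_dct(n_mfcc=13, n_mels=40, norm='ortho')
        >>> mel_specgram = torch.rand(1, 100, 40)  # (channel, time, n_mels)
        >>> mfcc = torch.matmul(mel_specgram, dct_mat)  # (channel, time, n_mfcc)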
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def mu_law_encoding(
        x: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law encoding
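
    Example
        An illustrative sketch (assumes a waveform already scaled to [-1, 1]):

        >>> waveform = torch.linspace(-1., 1., steps=5)
        >>> x_mu = mu_law_encoding(waveform, quantization_channels=256)
        >>> x_mu.dtype
        torch.int64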
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(
        x_mu: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
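
    Example
        An illustrative sketch (round-trips the encoding example above; the result is a
        lossy approximation of the original waveform):

        >>> waveform = torch.linspace(-1., 1., steps=5)
        >>> x_mu = mu_law_encoding(waveform, quantization_channels=256)
        >>> decoded = mu_law_decoding(x_mu, quantization_channels=256)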
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = ((x_mu) / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x


@_mod_utils.deprecated(
    "Please convert the input Tensor to complex type with `torch.view_as_complex` then "
    "use `torch.abs`. "
    "Please refer to https://github.com/pytorch/audio/issues/1337 "
    "for more details about torchaudio's plan to migrate to native complex type."
)
def complex_norm(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tensor:
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        Tensor: Power of the normed input tensor. Shape of `(..., )`
    """
    # Replace by torch.norm once issue is fixed
    # https://github.com/pytorch/pytorch/issues/34279
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)


@_mod_utils.deprecated(
    "Please convert the input Tensor to complex type with `torch.view_as_complex` then "
    "use `torch.angle`. "
    "Please refer to https://github.com/pytorch/audio/issues/1337 "
    "for more details about torchaudio's plan to migrate to native complex type."
)
def angle(
        complex_tensor: Tensor
) -> Tensor:
    r"""Compute the angle of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`

    Return:
        Tensor: Angle of a complex tensor. Shape of `(..., )`
    """
    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


def magphase(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tuple[Tensor, Tensor]:
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        (Tensor, Tensor): The magnitude and phase of the complex tensor
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase


def phase_vocoder(
        complex_specgrams: Tensor,
        rate: float,
        phase_advance: Tensor
) -> Tensor:
    r"""Given an STFT tensor, speed up in time without modifying pitch by a
    factor of ``rate``.

    Args:
        complex_specgrams (Tensor):
            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

    Returns:
        Tensor:
            Stretched spectrogram. The resulting tensor is of the same dtype as the input
            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

    Example - With Tensor of complex dtype
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time)
        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231])

    Example - With Tensor of real dtype and extra dimension for complex field
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """
    if rate == 1.0:
        return complex_specgrams

    if not complex_specgrams.is_complex():
        warnings.warn(
            "The use of pseudo complex type in `torchaudio.functional.phase_vocoder` and "
            "`torchaudio.transforms.TimeStretch` is now deprecated. "
            "Please migrate to native complex type by converting the input tensor with "
            "`torch.view_as_complex`. "
            "Please refer to https://github.com/pytorch/audio/issues/1337 "
            "for more details about torchaudio's plan to migrate to native complex type."
        )
        if complex_specgrams.size(-1) != 2:
            raise ValueError(
                "complex_specgrams must be either native complex tensors or "
                "real valued tensors with shape (..., 2)")

    is_complex = complex_specgrams.is_complex()

    if not is_complex:
        complex_specgrams = torch.view_as_complex(complex_specgrams)

    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))

    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
    # Note torch.real is a view so it does not incur any memory copy.
    real_dtype = torch.real(complex_specgrams).dtype
    time_steps = torch.arange(
        0,
        complex_specgrams.size(-1),
        rate,
        device=complex_specgrams.device,
        dtype=real_dtype)

    alphas = time_steps % 1.0
    phase_0 = complex_specgrams[..., :1].angle()

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

    angle_0 = complex_specgrams_0.angle()
    angle_1 = complex_specgrams_1.angle()

    norm_0 = complex_specgrams_0.abs()
    norm_1 = complex_specgrams_1.abs()

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    complex_specgrams_stretch = torch.polar(mag, phase_acc)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

    if not is_complex:
        return torch.view_as_real(complex_specgrams_stretch)
    return complex_specgrams_stretch


def mask_along_axis_iid(
        specgrams: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

    Args:
        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

    Returns:
        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
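
    Example
        An illustrative sketch (each example in the batch gets its own randomly placed
        time mask of at most 30 frames):

        >>> specgrams = torch.randn(4, 2, 201, 100)  # (batch, channel, freq, time)
        >>> masked = mask_along_axis_iid(specgrams, mask_param=30, mask_value=0., axis=3)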
    """

    if axis != 2 and axis != 3:
        raise ValueError('Only Frequency and Time masking are supported')

    device = specgrams.device
    dtype = specgrams.dtype

    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value[..., None, None]
    mask_end = (min_value + value)[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(
        specgram: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram (channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

    Returns:
        Tensor: Masked spectrogram of dimensions (channel, freq, time)
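
    Example
        An illustrative sketch (a single frequency mask of at most 27 bins, shared by
        all channels):

        >>> specgram = torch.randn(2, 201, 100)  # (channel, freq, time)
        >>> masked = mask_along_axis(specgram, mask_param=27, mask_value=0., axis=1)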
    """

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()

    assert mask_end - mask_start < mask_param
    if axis == 1:
        specgram[:, mask_start:mask_end] = mask_value
    elif axis == 2:
        specgram[:, :, mask_start:mask_end] = mask_value
    else:
        raise ValueError('Only Frequency and Time masking are supported')

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram


def compute_deltas(
        specgram: Tensor,
        win_length: int = 5,
        mode: str = "replicate"
) -> Tensor:
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. math::
       d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is ``(win_length-1)//2``.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

    Returns:
        Tensor: Tensor of deltas of dimension (..., freq, time)

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """
    device = specgram.device
    dtype = specgram.dtype

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def _compute_nccf(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float,
        freq_low: int
) -> Tensor:
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = int(math.ceil(sample_rate / freq_low))

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = int(math.ceil(waveform_length / frame_size))

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(
        a: Tuple[Tensor, Tensor],
        b: Tuple[Tensor, Tensor],
        thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = (a[0] > thresh * b[0])
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(
        nccf: Tensor,
        sample_rate: int,
        freq_high: int
) -> Tensor:
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the maximum within the first half of the lags comes very close
    to the overall maximum, the smaller lag is taken.
    """

    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(
        indices: Tensor,
        win_length: int
) -> Tensor:
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(
        indices, (pad_length, 0), mode="constant", value=0.
    )

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float = 10 ** (-2),
        win_length: int = 30,
        freq_low: int = 85,
        freq_high: int = 3400,
) -> Tensor:
    r"""Detect pitch frequency.

    It is implemented using normalized cross-correlation function and median smoothing.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
        freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

    Returns:
        Tensor: Tensor of freq of dimension (..., frame)
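
    Example
        An illustrative sketch (assumes one second of audio at 44.1 kHz; a pure 440 Hz
        tone should yield estimates close to 440):

        >>> sample_rate = 44100
        >>> t = torch.arange(sample_rate) / sample_rate
        >>> waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)
        >>> pitch = detect_pitch_frequency(waveform, sample_rate)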
    """
    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq


def sliding_window_cmn(
    specgram: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Tensor matching input shape (..., time, freq)
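
    Example
        An illustrative sketch (assumes Kaldi-style features of shape ``(channel, time, freq)``):

        >>> features = torch.randn(1, 1000, 40)
        >>> normalized = sliding_window_cmn(features, cmn_window=600, center=True)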
    """
    input_shape = specgram.shape
    num_frames, num_feats = input_shape[-2:]
    specgram = specgram.view(-1, num_frames, num_feats)
    num_channels = specgram.shape[0]

    dtype = specgram.dtype
    device = specgram.device
    last_window_start = last_window_end = -1
    cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cmn_specgram = torch.zeros(
        num_channels, num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            input_part = specgram[:, window_start: window_end - window_start, :]
            cur_sum += torch.sum(input_part, 1)
            if norm_vars:
                cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
        else:
            if window_start > last_window_start:
                frame_to_remove = specgram[:, last_window_start, :]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = specgram[:, last_window_end, :]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                cmn_specgram[:, t, :] = torch.zeros(
                    num_channels, num_feats, dtype=dtype, device=device)
            else:
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmn_specgram[:, t, :] *= variance

    cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
    if len(input_shape) == 2:
        cmn_specgram = cmn_specgram.squeeze(0)
    return cmn_specgram


def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time)
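
    Example
        An illustrative sketch (assumes one second of audio at 16 kHz):

        >>> sample_rate = 16000
        >>> waveform = torch.randn(1, sample_rate)
        >>> n_fft = 400
        >>> centroid = spectral_centroid(waveform, sample_rate, 0,
        >>>                              torch.hann_window(n_fft), n_fft, 200, n_fft)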
    """
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
                           device=specgram.device).reshape((-1, 1))
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)


@_mod_utils.requires_sox()
def apply_codec(
    waveform: Tensor,
    sample_rate: int,
    format: str,
    channels_first: bool = True,
    compression: Optional[float] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool):
            When True, both the input and output Tensor have dimension ``[channel, time]``.
            Otherwise, they have dimension ``[time, channel]``.
        compression (float): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        torch.Tensor: Resulting Tensor.
        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
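
    Example
        An illustrative sketch (requires the sox backend; assumes a mono waveform scaled
        to [-1, 1] and an MP3-capable build):

        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> degraded = apply_codec(waveform, sample_rate=16000, format="mp3")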
    """
    bytes = io.BytesIO()
    torchaudio.backend.sox_io_backend.save(bytes,
                                           waveform,
                                           sample_rate,
                                           channels_first,
                                           compression,
                                           format,
                                           encoding,
                                           bits_per_sample
                                           )
    bytes.seek(0)
    augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
        bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
    return augmented


1189
@_mod_utils.requires_kaldi()
moto's avatar
moto committed
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
def compute_kaldi_pitch(
        waveform: torch.Tensor,
        sample_rate: float,
        frame_length: float = 25.0,
        frame_shift: float = 10.0,
        min_f0: float = 50,
        max_f0: float = 400,
        soft_min_f0: float = 10.0,
        penalty_factor: float = 0.1,
        lowpass_cutoff: float = 1000,
        resample_frequency: float = 4000,
        delta_pitch: float = 0.005,
        nccf_ballast: float = 7000,
        lowpass_filter_width: int = 1,
        upsample_filter_width: int = 5,
        max_frames_latency: int = 0,
        frames_per_chunk: int = 0,
        simulate_first_pass_online: bool = False,
        recompute_frame: int = 500,
        snip_edges: bool = True,
) -> torch.Tensor:
    """Extract pitch based on method described in [1].

    This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.

    Args:
        waveform (Tensor):
            The input waveform of shape `(..., time)`.
        sample_rate (float):
            Sample rate of `waveform`.
        frame_length (float, optional):
moto's avatar
moto committed
1221
            Frame length in milliseconds. (default: 25.0)
moto's avatar
moto committed
1222
        frame_shift (float, optional):
moto's avatar
moto committed
1223
            Frame shift in milliseconds. (default: 10.0)
moto's avatar
moto committed
1224
        min_f0 (float, optional):
moto's avatar
moto committed
1225
            Minimum F0 to search for (Hz)  (default: 50.0)
moto's avatar
moto committed
1226
        max_f0 (float, optional):
moto's avatar
moto committed
1227
            Maximum F0 to search for (Hz)  (default: 400.0)
moto's avatar
moto committed
1228
        soft_min_f0 (float, optional):
moto's avatar
moto committed
1229
            Minimum f0, applied in soft way, must not exceed min-f0  (default: 10.0)
moto's avatar
moto committed
1230
        penalty_factor (float, optional):
moto's avatar
moto committed
1231
            Cost factor for FO change.  (default: 0.1)
moto's avatar
moto committed
1232
        lowpass_cutoff (float, optional):
moto's avatar
moto committed
1233
            Cutoff frequency for LowPass filter (Hz) (default: 1000)
moto's avatar
moto committed
1234
1235
        resample_frequency (float, optional):
            Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
moto's avatar
moto committed
1236
            (default: 4000)
moto's avatar
moto committed
1237
        delta_pitch( float, optional):
moto's avatar
moto committed
1238
            Smallest relative change in pitch that our algorithm measures. (default: 0.005)
moto's avatar
moto committed
1239
        nccf_ballast (float, optional):
moto's avatar
moto committed
1240
            Increasing this factor reduces NCCF for quiet frames (default: 7000)
moto's avatar
moto committed
1241
1242
        lowpass_filter_width (int, optional):
            Integer that determines filter width of lowpass filter, more gives sharper filter.
moto's avatar
moto committed
1243
            (default: 1)
moto's avatar
moto committed
1244
        upsample_filter_width (int, optional):
moto's avatar
moto committed
1245
            Integer that determines filter width when upsampling NCCF. (default: 5)
moto's avatar
moto committed
1246
1247
1248
        max_frames_latency (int, optional):
            Maximum number of frames of latency that we allow pitch tracking to introduce into
            the feature processing (affects output only if ``frames_per_chunk > 0`` and
moto's avatar
moto committed
1249
            ``simulate_first_pass_online=True``) (default: 0)
moto's avatar
moto committed
1250
        frames_per_chunk (int, optional):
moto's avatar
moto committed
1251
            The number of frames used for energy normalization. (default: 0)
moto's avatar
moto committed
1252
1253
1254
        simulate_first_pass_online (bool, optional):
            If true, the function will output features that correspond to what an online decoder
            would see in the first pass of decoding -- not the final version of the features,
moto's avatar
moto committed
1255
            which is the default. (default: False)
moto's avatar
moto committed
1256
1257
1258
1259
1260
            Relevant if ``frames_per_chunk > 0``.
        recompute_frame (int, optional):
            Only relevant for compatibility with online pitch extraction.
            A non-critical parameter; the frame at which we recompute some of the forward pointers,
            after revising our estimate of the signal energy.
moto's avatar
moto committed
1261
            Relevant if ``frames_per_chunk > 0``. (default: 500)
moto's avatar
moto committed
1262
1263
1264
        snip_edges (bool, optional):
            If this is set to false, the incomplete frames near the ending edge won't be snipped,
            so that the number of frames is the file size divided by the frame-shift.
            This makes different types of features give the same number of frames. (default: True)

    Returns:
       Tensor: Pitch feature. Shape: ``(batch, frames, 2)`` where the last dimension
       corresponds to pitch and NCCF.
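
    Example (a minimal sketch; the random tensor below only stands in for real speech):
        >>> sample_rate = 16000
        >>> waveform = torch.randn(2, sample_rate)  # (channel, time), one second of audio
        >>> pitch_feature = compute_kaldi_pitch(waveform, sample_rate)
        >>> pitch_feature.shape[-1]  # last dimension holds the pitch and NCCF values
        2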

    Reference:
        - A pitch extraction algorithm tuned for automatic speech recognition

          P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur

          2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
          Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
    """
    shape = waveform.shape
    waveform = waveform.reshape(-1, shape[-1])
    result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
        waveform, sample_rate, frame_length, frame_shift,
        min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
        resample_frequency, delta_pitch, nccf_ballast,
        lowpass_filter_width, upsample_filter_width, max_frames_latency,
        frames_per_chunk, simulate_first_pass_online, recompute_frame,
        snip_edges,
    )
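    # unpack batch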
    result = result.reshape(shape[:-1] + result.shape[-2:])
    return result


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
                              device: torch.device, dtype: torch.dtype):
    assert lowpass_filter_width > 0
    kernels = []
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= 0.99

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #    y[j] = x(j / new_freq)
    # or,
    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zero crossings.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This explains the F.conv1d below, which is applied with a stride of orig_freq.
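    # Concrete illustration (hypothetical numbers): with orig_freq = 3 and new_freq = 2
    # (say 48 kHz -> 32 kHz after GCD reduction), only new_freq = 2 distinct filters are needed:
    # y[0], y[2], y[4], ... reuse the first filter on input windows that advance by orig_freq = 3
    # samples, while y[1], y[3], ... reuse the second filter in the same strided fashion.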
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

    for i in range(new_freq):
        t = (-i / new_freq + idx / orig_freq) * base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
        t *= math.pi
        # we do not use torch.hann_window here as we need to evaluate the window
        # at specific positions, not over a regular grid.
        window = torch.cos(t / lowpass_filter_width / 2)**2
        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
        kernel.mul_(window)
        kernels.append(kernel)

    scale = base_freq / orig_freq
    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def resample(
        waveform: Tensor,
        orig_freq: float,
        new_freq: float,
        lowpass_filter_width: int = 6
) -> Tensor:
    r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
    which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
    a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
    the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
    https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56

    Args:
        waveform (Tensor): The input signal of dimension (..., time)
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)

    Returns:
        Tensor: The waveform at the new frequency of dimension (..., time).
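
    Example (a minimal sketch; the random tensor below only stands in for real audio):
        >>> waveform = torch.randn(1, 16000)  # one second at 16 kHz, shape (channel, time)
        >>> resampled = resample(waveform, orig_freq=16000, new_freq=8000)
        >>> resampled.shape
        torch.Size([1, 8000])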
    """
    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    assert orig_freq > 0.0 and new_freq > 0.0

    orig_freq = int(orig_freq)
    new_freq = int(new_freq)
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq = orig_freq // gcd
    new_freq = new_freq // gcd

    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                              waveform.device, waveform.dtype)

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch
    resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
    return resampled