# -*- coding: utf-8 -*-

import io
import math
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
import torchaudio

__all__ = [
    "spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "DB_to_amplitude",
    "compute_deltas",
    "compute_kaldi_pitch",
    "create_fb_matrix",
    "create_dct",
    "detect_pitch_frequency",
    "mu_law_encoding",
    "mu_law_decoding",
    "complex_norm",
    "angle",
    "magphase",
    "phase_vocoder",
    'mask_along_axis',
    'mask_along_axis_iid',
    'sliding_window_cmn',
    "spectral_centroid",
    "apply_codec",
    "resample",
]


def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True,
        return_complex: bool = False,
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        return_complex (bool, optional):
            When ``return_complex = True``, this function returns the resulting Tensor in
            complex dtype, otherwise it returns the resulting Tensor in real dtype with an extra
            dimension for the real and imaginary parts (see ``torch.view_as_real``).
            When ``power`` is provided, the value must be False, as the resulting
            Tensor represents real-valued power.

    Returns:
        Tensor: Dimension (..., freq, time), where freq is
        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
    """
    if power is None and not return_complex:
        warnings.warn(
            "The use of pseudo complex type in spectrogram is now deprecated. "
            "Please migrate to native complex type by providing `return_complex=True`. "
            "Please refer to https://github.com/pytorch/audio/issues/1337 "
            "for more details about torchaudio's plan to migrate to native complex type."
        )

    if power is not None and return_complex:
        raise ValueError(
            'When `power` is provided, the return value is real-valued. '
            'Therefore, `return_complex` must be False.')

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=False,
        onesided=onesided,
        return_complex=True,
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

    if normalized:
        spec_f /= window.pow(2.).sum().sqrt()
    if power is not None:
        if power == 1.0:
            return spec_f.abs()
        return spec_f.abs().pow(power)
    if not return_complex:
        return torch.view_as_real(spec_f)
    return spec_f
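# Illustrative usage sketch for `spectrogram` (not part of the original module;
# the waveform, window and STFT parameters below are arbitrary assumptions):
#
#     >>> waveform = torch.randn(2, 16000)          # (channel, time)
#     >>> window = torch.hann_window(400)
#     >>> spec = spectrogram(waveform, pad=0, window=window, n_fft=400,
#     ...                    hop_length=200, win_length=400, power=2.,
#     ...                    normalized=False)
#     >>> spec.shape
#     torch.Size([2, 201, 81])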


def _get_complex_dtype(real_dtype: torch.dtype):
    if real_dtype == torch.double:
        return torch.cdouble
    if real_dtype == torch.float:
        return torch.cfloat
    if real_dtype == torch.half:
        return torch.complex32
    raise ValueError(f'Unexpected dtype {real_dtype}')


def griffinlim(
        specgram: Tensor,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: float,
        n_iter: int,
        momentum: float,
        length: Optional[int],
        rand_init: bool
) -> Tensor:
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    *  [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.
    *  [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.
    *  [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows. (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (int or None): Array length of the expected output.
        rand_init (bool): Initializes phase randomly if True, to zero otherwise.

    Returns:
        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
    """
    assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
    assert momentum >= 0, 'momentum={} < 0'.format(momentum)

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    specgram = specgram.pow(1 / power)

    # initialize the phase
    if rand_init:
        angles = torch.rand(
            specgram.size(),
            dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
    else:
        angles = torch.full(
            specgram.size(), 1,
            dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

    # And initialize the previous iterate to 0
    tprev = torch.tensor(0., dtype=specgram.dtype, device=specgram.device)
    for _ in range(n_iter):
        # Invert with our current estimate of the phases
        inverse = torch.istft(specgram * angles,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              window=window,
                              length=length)

        # Rebuild the spectrogram
        rebuilt = torch.stft(
            input=inverse,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # Update our phase estimates
        angles = rebuilt
        if momentum:
            angles = angles - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div(angles.abs().add(1e-16))

        # Store the previous iterate
        tprev = rebuilt

    # Return the final phase estimates
    waveform = torch.istft(specgram * angles,
                           n_fft=n_fft,
                           hop_length=hop_length,
                           win_length=win_length,
                           window=window,
                           length=length)

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform
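# Illustrative usage sketch for `griffinlim` (assumed parameter values): a magnitude
# spectrogram is inverted back to a waveform with 32 fast Griffin-Lim iterations.
#
#     >>> n_fft, hop_length = 400, 200
#     >>> window = torch.hann_window(n_fft)
#     >>> specgram = spectrogram(torch.randn(1, 16000), pad=0, window=window,
#     ...                        n_fft=n_fft, hop_length=hop_length, win_length=n_fft,
#     ...                        power=2., normalized=False)
#     >>> waveform = griffinlim(specgram, window, n_fft, hop_length, n_fft,
#     ...                       power=2., n_iter=32, momentum=0.99,
#     ...                       length=16000, rand_init=True)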


def amplitude_to_DB(
        x: Tensor,
        multiplier: float,
        amin: float,
        db_multiplier: float,
        top_db: Optional[float] = None
) -> Tensor:
    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

    The output of each tensor in a batch depends on the maximum value of that tensor,
    and so may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
          the form `(..., freq, time)`. Batched inputs should include a channel dimension and
          have the form `(batch, channel, freq, time)`.
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        Tensor: Output tensor in decibel scale
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        # Expand batch
        shape = x_db.size()
        packed_channels = shape[-3] if x_db.dim() > 2 else 1
        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])

        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))

        # Repack batch
        x_db = x_db.reshape(shape)

    return x_db
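# Illustrative usage sketch for `amplitude_to_DB` (assumed values): converting a
# power spectrogram to decibels with an 80 dB dynamic-range cut-off.
#
#     >>> power_spec = torch.rand(1, 201, 100)      # (channel, freq, time)
#     >>> db_spec = amplitude_to_DB(power_spec, multiplier=10., amin=1e-10,
#     ...                           db_multiplier=0., top_db=80.)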


def DB_to_amplitude(
        x: Tensor,
        ref: float,
        power: float
) -> Tensor:
    r"""Turn a tensor from the decibel scale to the power/amplitude scale.

    Args:
        x (Tensor): Input tensor before being converted to power/amplitude scale.
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Returns:
        Tensor: Output tensor in power/amplitude scale.
    """
    return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequencies in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted in Hz
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0**(mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = (mels >= min_log_mel)
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs


def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (Optional[str]): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
        meaning number of frequencies to highlight/apply times the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * create_fb_matrix(A.size(-1), ...)``.
    """

    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
    m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
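# Illustrative usage sketch for `create_fb_matrix` (assumed values): building a mel
# filterbank and applying it to a magnitude spectrogram as described in the
# docstring above.
#
#     >>> fb = create_fb_matrix(n_freqs=201, f_min=0., f_max=8000.,
#     ...                       n_mels=40, sample_rate=16000)
#     >>> spec = torch.rand(1, 201, 100)            # (channel, freq, time)
#     >>> mel_spec = torch.matmul(spec.transpose(-1, -2), fb).transpose(-1, -2)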


def create_dct(
        n_mfcc: int,
        n_mels: int,
        norm: Optional[str]
) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def mu_law_encoding(
        x: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law encoding
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(
        x_mu: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = ((x_mu) / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x
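# Illustrative round-trip sketch for `mu_law_encoding` / `mu_law_decoding`
# (assumed values; the input must already be scaled to [-1, 1]):
#
#     >>> waveform = torch.rand(1, 100) * 2 - 1
#     >>> encoded = mu_law_encoding(waveform, quantization_channels=256)
#     >>> decoded = mu_law_decoding(encoded, quantization_channels=256)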


@_mod_utils.deprecated(
    "Please convert the input Tensor to complex type with `torch.view_as_complex` then "
    "use `torch.abs`. "
    "Please refer to https://github.com/pytorch/audio/issues/1337 "
    "for more details about torchaudio's plan to migrate to native complex type."
)
def complex_norm(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tensor:
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        Tensor: Power of the normed input tensor. Shape of `(..., )`
    """

    # Replace by torch.norm once issue is fixed
    # https://github.com/pytorch/pytorch/issues/34279
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)


@_mod_utils.deprecated(
    "Please convert the input Tensor to complex type with `torch.view_as_complex` then "
    "use `torch.angle`. "
    "Please refer to https://github.com/pytorch/audio/issues/1337 "
    "for more details about torchaudio's plan to migrate to native complex type."
)
def angle(
        complex_tensor: Tensor
) -> Tensor:
    r"""Compute the angle of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`

    Return:
        Tensor: Angle of a complex tensor. Shape of `(..., )`
    """
    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


def magphase(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tuple[Tensor, Tensor]:
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        (Tensor, Tensor): The magnitude and phase of the complex tensor
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase


def phase_vocoder(
        complex_specgrams: Tensor,
        rate: float,
        phase_advance: Tensor
) -> Tensor:
    r"""Given a STFT tensor, speed up in time without modifying pitch by a
    factor of ``rate``.

    Args:
        complex_specgrams (Tensor):
            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

    Returns:
        Tensor:
            Stretched spectrogram. The resulting tensor is of the same dtype as the input
            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

    Example - With Tensor of complex dtype
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time)
        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231])

    Example - With Tensor of real dtype and extra dimension for complex field
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """
    if rate == 1.0:
        return complex_specgrams

    if not complex_specgrams.is_complex():
        warnings.warn(
            "The use of pseudo complex type in `torchaudio.functional.phase_vocoder` and "
            "`torchaudio.transforms.TimeStretch` is now deprecated. "
            "Please migrate to native complex type by converting the input tensor with "
            "`torch.view_as_complex`. "
            "Please refer to https://github.com/pytorch/audio/issues/1337 "
            "for more details about torchaudio's plan to migrate to native complex type."
        )
        if complex_specgrams.size(-1) != 2:
            raise ValueError(
                "complex_specgrams must be either native complex tensors or "
                "real valued tensors with shape (..., 2)")

    is_complex = complex_specgrams.is_complex()

    if not is_complex:
        complex_specgrams = torch.view_as_complex(complex_specgrams)

    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))

    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
    # Note torch.real is a view so it does not incur any memory copy.
    real_dtype = torch.real(complex_specgrams).dtype
    time_steps = torch.arange(
        0,
        complex_specgrams.size(-1),
        rate,
        device=complex_specgrams.device,
        dtype=real_dtype)

    alphas = time_steps % 1.0
    phase_0 = complex_specgrams[..., :1].angle()

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

    angle_0 = complex_specgrams_0.angle()
    angle_1 = complex_specgrams_1.angle()

    norm_0 = complex_specgrams_0.abs()
    norm_1 = complex_specgrams_1.abs()

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    complex_specgrams_stretch = torch.polar(mag, phase_acc)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

    if not is_complex:
        return torch.view_as_real(complex_specgrams_stretch)
    return complex_specgrams_stretch


def mask_along_axis_iid(
        specgrams: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

    Args:
        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

    Returns:
        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
    """

    if axis != 2 and axis != 3:
        raise ValueError('Only Frequency and Time masking are supported')

    device = specgrams.device
    dtype = specgrams.dtype

    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value[..., None, None]
    mask_end = (min_value + value)[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(
        specgram: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram (channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

    Returns:
        Tensor: Masked spectrogram of dimensions (channel, freq, time)
    """
    if axis != 1 and axis != 2:
        raise ValueError('Only Frequency and Time masking are supported')

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()
    mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
    mask = (mask >= mask_start) & (mask < mask_end)
    if axis == 1:
        mask = mask.unsqueeze(-1)

    assert mask_end - mask_start < mask_param

    specgram = specgram.masked_fill(mask, mask_value)

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram
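# Illustrative usage sketch for `mask_along_axis` (assumed values): masking up to
# 30 consecutive time steps of a spectrogram, SpecAugment-style.
#
#     >>> spec = torch.rand(1, 201, 100)            # (channel, freq, time)
#     >>> masked = mask_along_axis(spec, mask_param=30, mask_value=0., axis=2)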


def compute_deltas(
        specgram: Tensor,
        win_length: int = 5,
        mode: str = "replicate"
) -> Tensor:
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. math::
       d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is ``(win_length-1)//2``.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

    Returns:
        Tensor: Tensor of deltas of dimension (..., freq, time)

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """
    device = specgram.device
    dtype = specgram.dtype

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def _compute_nccf(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float,
        freq_low: int
) -> Tensor:
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = int(math.ceil(sample_rate / freq_low))

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = int(math.ceil(waveform_length / frame_size))

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(
        a: Tuple[Tensor, Tensor],
        b: Tuple[Tensor, Tensor],
        thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = (a[0] > thresh * b[0])
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(
        nccf: Tensor,
        sample_rate: int,
        freq_high: int
) -> Tensor:
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the max among all the lags is very close
    to the first half of lags, then the latter is taken.
    """

    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(
        indices: Tensor,
        win_length: int
) -> Tensor:
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(
        indices, (pad_length, 0), mode="constant", value=0.
    )

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float = 10 ** (-2),
        win_length: int = 30,
        freq_low: int = 85,
        freq_high: int = 3400,
) -> Tensor:
    r"""Detect pitch frequency.

    It is implemented using normalized cross-correlation function and median smoothing.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
        freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

    Returns:
        Tensor: Tensor of freq of dimension (..., frame)
    """
    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq
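# Illustrative usage sketch for `detect_pitch_frequency` (assumed values): the
# estimated frequencies for a pure 220 Hz tone should be close to 220.
#
#     >>> sample_rate = 16000
#     >>> t = torch.arange(sample_rate) / sample_rate
#     >>> waveform = torch.sin(2 * math.pi * 220. * t)
#     >>> freq = detect_pitch_frequency(waveform, sample_rate)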


def sliding_window_cmn(
    specgram: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Tensor matching input shape (..., time, freq)
    """
    input_shape = specgram.shape
    num_frames, num_feats = input_shape[-2:]
    specgram = specgram.view(-1, num_frames, num_feats)
    num_channels = specgram.shape[0]

    dtype = specgram.dtype
    device = specgram.device
    last_window_start = last_window_end = -1
    cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cmn_specgram = torch.zeros(
        num_channels, num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            input_part = specgram[:, window_start: window_end - window_start, :]
            cur_sum += torch.sum(input_part, 1)
            if norm_vars:
                cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
        else:
            if window_start > last_window_start:
                frame_to_remove = specgram[:, last_window_start, :]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = specgram[:, last_window_end, :]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                cmn_specgram[:, t, :] = torch.zeros(
                    num_channels, num_feats, dtype=dtype, device=device)
            else:
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmn_specgram[:, t, :] *= variance

    cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
    if len(input_shape) == 2:
        cmn_specgram = cmn_specgram.squeeze(0)
    return cmn_specgram
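# Illustrative usage sketch for `sliding_window_cmn` (assumed values): normalizing
# a batch of MFCC-like features over a centered 600-frame window.
#
#     >>> feats = torch.randn(1, 500, 13)           # (channel, time, freq)
#     >>> normalized = sliding_window_cmn(feats, cmn_window=600, center=True)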


def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time)
    """
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
                           device=specgram.device).reshape((-1, 1))
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
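# Illustrative usage sketch for `spectral_centroid` (assumed values):
#
#     >>> waveform = torch.randn(1, 16000)
#     >>> centroid = spectral_centroid(waveform, sample_rate=16000, pad=0,
#     ...                              window=torch.hann_window(400), n_fft=400,
#     ...                              hop_length=200, win_length=400)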


@_mod_utils.requires_sox()
def apply_codec(
    waveform: Tensor,
    sample_rate: int,
    format: str,
    channels_first: bool = True,
    compression: Optional[float] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool):
            When True, both the input and output Tensor have dimension ``[channel, time]``.
            Otherwise, they have dimension ``[time, channel]``.
        compression (float): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        torch.Tensor: Resulting Tensor.
        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
    """
    bytes = io.BytesIO()
    torchaudio.backend.sox_io_backend.save(bytes,
                                           waveform,
                                           sample_rate,
                                           channels_first,
                                           compression,
                                           format,
                                           encoding,
                                           bits_per_sample
                                           )
    bytes.seek(0)
    augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
        bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
    return augmented
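# Illustrative usage sketch for `apply_codec` (assumed values; requires the sox
# backend and codec support for the chosen format to be available):
#
#     >>> waveform = torch.rand(1, 16000) * 2 - 1   # (channel, time)
#     >>> degraded = apply_codec(waveform, sample_rate=16000, format="mp3")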


1190
@_mod_utils.requires_kaldi()
moto's avatar
moto committed
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
def compute_kaldi_pitch(
        waveform: torch.Tensor,
        sample_rate: float,
        frame_length: float = 25.0,
        frame_shift: float = 10.0,
        min_f0: float = 50,
        max_f0: float = 400,
        soft_min_f0: float = 10.0,
        penalty_factor: float = 0.1,
        lowpass_cutoff: float = 1000,
        resample_frequency: float = 4000,
        delta_pitch: float = 0.005,
        nccf_ballast: float = 7000,
        lowpass_filter_width: int = 1,
        upsample_filter_width: int = 5,
        max_frames_latency: int = 0,
        frames_per_chunk: int = 0,
        simulate_first_pass_online: bool = False,
        recompute_frame: int = 500,
        snip_edges: bool = True,
) -> torch.Tensor:
    """Extract pitch based on method described in [1].

    This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.

    Args:
        waveform (Tensor):
            The input waveform of shape `(..., time)`.
        sample_rate (float):
            Sample rate of `waveform`.
        frame_length (float, optional):
moto's avatar
moto committed
1222
            Frame length in milliseconds. (default: 25.0)
moto's avatar
moto committed
1223
        frame_shift (float, optional):
moto's avatar
moto committed
1224
            Frame shift in milliseconds. (default: 10.0)
moto's avatar
moto committed
1225
        min_f0 (float, optional):
moto's avatar
moto committed
1226
            Minimum F0 to search for (Hz)  (default: 50.0)
moto's avatar
moto committed
1227
        max_f0 (float, optional):
moto's avatar
moto committed
1228
            Maximum F0 to search for (Hz)  (default: 400.0)
moto's avatar
moto committed
1229
        soft_min_f0 (float, optional):
moto's avatar
moto committed
1230
            Minimum f0, applied in soft way, must not exceed min-f0  (default: 10.0)
moto's avatar
moto committed
1231
        penalty_factor (float, optional):
moto's avatar
moto committed
1232
            Cost factor for FO change.  (default: 0.1)
moto's avatar
moto committed
1233
        lowpass_cutoff (float, optional):
moto's avatar
moto committed
1234
            Cutoff frequency for LowPass filter (Hz) (default: 1000)
moto's avatar
moto committed
1235
1236
        resample_frequency (float, optional):
            Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
moto's avatar
moto committed
1237
            (default: 4000)
moto's avatar
moto committed
1238
        delta_pitch( float, optional):
moto's avatar
moto committed
1239
            Smallest relative change in pitch that our algorithm measures. (default: 0.005)
moto's avatar
moto committed
1240
        nccf_ballast (float, optional):
moto's avatar
moto committed
1241
            Increasing this factor reduces NCCF for quiet frames (default: 7000)
moto's avatar
moto committed
1242
1243
        lowpass_filter_width (int, optional):
            Integer that determines filter width of lowpass filter, more gives sharper filter.
moto's avatar
moto committed
1244
            (default: 1)
moto's avatar
moto committed
1245
        upsample_filter_width (int, optional):
moto's avatar
moto committed
1246
            Integer that determines filter width when upsampling NCCF. (default: 5)
moto's avatar
moto committed
1247
1248
1249
        max_frames_latency (int, optional):
            Maximum number of frames of latency that we allow pitch tracking to introduce into
            the feature processing (affects output only if ``frames_per_chunk > 0`` and
moto's avatar
moto committed
1250
            ``simulate_first_pass_online=True``) (default: 0)
moto's avatar
moto committed
1251
        frames_per_chunk (int, optional):
moto's avatar
moto committed
1252
            The number of frames used for energy normalization. (default: 0)
moto's avatar
moto committed
1253
1254
1255
        simulate_first_pass_online (bool, optional):
            If true, the function will output features that correspond to what an online decoder
            would see in the first pass of decoding -- not the final version of the features,
moto's avatar
moto committed
1256
            which is the default. (default: False)
moto's avatar
moto committed
1257
1258
1259
1260
1261
            Relevant if ``frames_per_chunk > 0``.
        recompute_frame (int, optional):
            Only relevant for compatibility with online pitch extraction.
            A non-critical parameter; the frame at which we recompute some of the forward pointers,
            after revising our estimate of the signal energy.
moto's avatar
moto committed
1262
            Relevant if ``frames_per_chunk > 0``. (default: 500)
        snip_edges (bool, optional):
            If this is set to false, the incomplete frames near the ending edge won't be snipped,
            so that the number of frames is the file size divided by the frame-shift.
            This makes different types of features give the same number of frames. (default: True)

    Returns:
       Tensor: Pitch feature. Shape: ``(batch, frames, 2)`` where the last dimension
       corresponds to pitch and NCCF.
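
    Example (an illustrative sketch; the audio file path is hypothetical)::

        >>> # Load a recording; any sample rate supported by torchaudio.load works.
        >>> waveform, sample_rate = torchaudio.load("speech.wav")
        >>> pitch_feature = torchaudio.functional.compute_kaldi_pitch(waveform, sample_rate)
        >>> # Per the Returns note above, the last dimension holds pitch and NCCF.
        >>> pitch, nccf = pitch_feature[..., 0], pitch_feature[..., 1]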

    Reference:
        - A pitch extraction algorithm tuned for automatic speech recognition

          P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur

          2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
          Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
    """
    shape = waveform.shape
    waveform = waveform.reshape(-1, shape[-1])
    result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
        waveform, sample_rate, frame_length, frame_shift,
        min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
        resample_frequency, delta_pitch, nccf_ballast,
        lowpass_filter_width, upsample_filter_width, max_frames_latency,
        frames_per_chunk, simulate_first_pass_online, recompute_frame,
        snip_edges,
    )
    result = result.reshape(shape[:-1] + result.shape[-2:])
    return result


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
                              device: torch.device, dtype: torch.dtype):
    assert lowpass_filter_width > 0
    kernels = []
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= 0.99

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #    y[j] = x(j / new_freq)
    # or,
    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zero crossings.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This will explain the F.conv1d after, with a stride of orig_freq.
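    # Concretely (an illustrative example): after GCD reduction, resampling 4 -> 3 needs only
    # 3 distinct filters: y[0], y[3], y[6], ... reuse one filter on x shifted by 0, 4, 8, ...
    # samples, y[1], y[4], ... reuse a second, and y[2], y[5], ... the third.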
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

    for i in range(new_freq):
        t = (-i / new_freq + idx / orig_freq) * base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
        t *= math.pi
        # we do not use torch.hann_window here as we need to evaluate the window
        # at specific positions, not over a regular grid.
        window = torch.cos(t / lowpass_filter_width / 2)**2
        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
        kernel.mul_(window)
        kernels.append(kernel)

    scale = base_freq / orig_freq
    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def resample(
        waveform: Tensor,
        orig_freq: float,
        new_freq: float,
        lowpass_filter_width: int = 6
) -> Tensor:
    r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
    which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
    a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e.,
    the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
    https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56

    Args:
        waveform (Tensor): The input signal of dimension (..., time)
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)

    Returns:
        Tensor: The waveform at the new frequency of dimension (..., time).
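
    Example (an illustrative sketch; the sample rates are arbitrary)::

        >>> waveform = torch.randn(1, 44100)  # one second of noise at 44100 Hz
        >>> resampled = torchaudio.functional.resample(waveform, 44100, 16000)
        >>> resampled.shape
        torch.Size([1, 16000])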
    """
    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    assert orig_freq > 0.0 and new_freq > 0.0

    orig_freq = int(orig_freq)
    new_freq = int(new_freq)
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq = orig_freq // gcd
    new_freq = new_freq // gcd
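    # e.g. 44100 Hz -> 16000 Hz: gcd is 100, so here orig_freq = 441 and new_freq = 160,
    # i.e. the kernel holds 160 filters, each applied with a stride of 441 input samples.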

    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                              waveform.device, waveform.dtype)

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
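    # Keep only the first ceil(new_freq * length / orig_freq) samples, e.g. 1000 input samples
    # resampled from 441 to 160 give ceil(160 * 1000 / 441) = 363 output samples.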
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch
    resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
    return resampled