# -*- coding: utf-8 -*-

import io
import math
import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
import torchaudio

__all__ = [
    "spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "DB_to_amplitude",
    "compute_deltas",
    "compute_kaldi_pitch",
    "create_fb_matrix",
    "create_dct",
    "detect_pitch_frequency",
    "mu_law_encoding",
    "mu_law_decoding",
    "complex_norm",
    "angle",
    "magphase",
    "phase_vocoder",
    "mask_along_axis",
    "mask_along_axis_iid",
    "sliding_window_cmn",
    "spectral_centroid",
    "apply_codec",
    "resample",
]


def spectrogram(
        waveform: Tensor,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: Optional[float],
        normalized: bool,
        center: bool = True,
        pad_mode: str = "reflect",
        onesided: bool = True
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``

    Returns:
        Tensor: Dimension (..., freq, time), where freq is
        ``n_fft // 2 + 1``, ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
    """
    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=False,
        onesided=onesided,
        return_complex=True,
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

    if normalized:
        spec_f /= window.pow(2.).sum().sqrt()
    if power is not None:
        if power == 1.0:
            return spec_f.abs()
        return spec_f.abs().pow(power)
    return torch.view_as_real(spec_f)
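

# Illustrative usage sketch (not part of the torchaudio API): the parameters
# below are arbitrary assumptions for a hypothetical 1-second, 16 kHz mono clip.
def _example_spectrogram() -> None:
    waveform = torch.randn(1, 16000)                     # (channel, time)
    window = torch.hann_window(400)
    spec = spectrogram(waveform, pad=0, window=window, n_fft=400,
                       hop_length=200, win_length=400, power=2.,
                       normalized=False)
    # freq = n_fft // 2 + 1 = 201; the last dimension is the number of hops
    assert spec.shape[:2] == (1, 201)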


def griffinlim(
        specgram: Tensor,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        power: float,
        n_iter: int,
        momentum: float,
        length: Optional[int],
        rand_init: bool
) -> Tensor:
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
        Implementation ported from `librosa`.

    *  [1] McFee, Brian, Colin Raffel, Dawen Liang, Daniel PW Ellis, Matt McVicar, Eric Battenberg, and Oriol Nieto.
        "librosa: Audio and music signal analysis in python."
        In Proceedings of the 14th python in science conference, pp. 18-25. 2015.
    *  [2] Perraudin, N., Balazs, P., & Søndergaard, P. L.
        "A fast Griffin-Lim algorithm,"
        IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (pp. 1-4),
        Oct. 2013.
    *  [3] D. W. Griffin and J. S. Lim,
        "Signal estimation from modified short-time Fourier transform,"
        IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (int or None): Array length of the expected output.
        rand_init (bool): Initializes phase randomly if True, to zero otherwise.

    Returns:
        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
    """
    assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
    assert momentum >= 0, 'momentum={} < 0'.format(momentum)

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    specgram = specgram.pow(1 / power)

    # randomly initialize the phase
    batch, freq, frames = specgram.size()
    if rand_init:
        angles = 2 * math.pi * torch.rand(batch, freq, frames)
    else:
        angles = torch.zeros(batch, freq, frames)
    angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
        .to(dtype=specgram.dtype, device=specgram.device)
    specgram = specgram.unsqueeze(-1).expand_as(angles)

    # And initialize the previous iterate to 0
    rebuilt = torch.tensor(0.)

    for _ in range(n_iter):
        # Store the previous iterate
        tprev = rebuilt

        # Invert with our current estimate of the phases
        inverse = torch.istft(specgram * angles,
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              window=window,
                              length=length).float()

        # Rebuild the spectrogram
        rebuilt = torch.view_as_real(
            torch.stft(
                input=inverse,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                window=window,
                center=True,
                pad_mode='reflect',
                normalized=False,
                onesided=True,
                return_complex=True,
            )
        )

        # Update our phase estimates
        angles = rebuilt
        if momentum:
            angles = angles - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div(complex_norm(angles).add(1e-16).unsqueeze(-1).expand_as(angles))

    # Return the final phase estimates
    waveform = torch.istft(specgram * angles,
                           n_fft=n_fft,
                           hop_length=hop_length,
                           win_length=win_length,
                           window=window,
                           length=length)

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform
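

# Round-trip sketch (illustrative): rebuild a waveform from a magnitude
# spectrogram. Parameter values here are assumptions, not recommendations.
def _example_griffinlim() -> None:
    n_fft, hop_length = 400, 200
    window = torch.hann_window(n_fft)
    waveform = torch.randn(1, 16000)
    specgram = spectrogram(waveform, pad=0, window=window, n_fft=n_fft,
                           hop_length=hop_length, win_length=n_fft,
                           power=1., normalized=False)
    restored = griffinlim(specgram, window, n_fft, hop_length, n_fft,
                          power=1., n_iter=32, momentum=0.99,
                          length=waveform.size(-1), rand_init=True)
    assert restored.shape == waveform.shape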


def amplitude_to_DB(
        x: Tensor,
        multiplier: float,
        amin: float,
        db_multiplier: float,
        top_db: Optional[float] = None
) -> Tensor:
    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

    The output of each tensor in a batch depends on the maximum value of that tensor,
    and so may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
          the form `(..., freq, time)`. Batched inputs should include a channel dimension and
          have the form `(batch, channel, freq, time)`.
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        Tensor: Output tensor in decibel scale
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        # Expand batch
        shape = x_db.size()
        packed_channels = shape[-3] if x_db.dim() > 2 else 1
        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])

        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))

        # Repack batch
        x_db = x_db.reshape(shape)

    return x_db
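

# Worked sketch (illustrative): convert a power spectrogram to dB with an
# 80 dB floor. ``multiplier=10.`` is the power-scale convention; a
# ``db_multiplier`` of 0.0 corresponds to a reference value of 1.0.
def _example_amplitude_to_DB() -> None:
    power_spec = torch.tensor([[1.0, 0.1, 1e-10]])
    db = amplitude_to_DB(power_spec, multiplier=10., amin=1e-10,
                         db_multiplier=0.0, top_db=80.)
    # 1.0 -> 0 dB and 0.1 -> -10 dB; 1e-10 (-100 dB) is clamped to -80 dB
    assert torch.allclose(db, torch.tensor([[0., -10., -80.]]))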


def DB_to_amplitude(
        x: Tensor,
        ref: float,
        power: float
) -> Tensor:
    r"""Turn a tensor from the decibel scale to the power/amplitude scale.

    Args:
        x (Tensor): Input tensor before being converted to power/amplitude scale.
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Returns:
        Tensor: Output tensor in power/amplitude scale.
    """
    return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
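

# Inverse sketch (illustrative): ``DB_to_amplitude(x, 1., 1.)`` undoes
# ``amplitude_to_DB`` on the power scale when no ``top_db`` clamping occurs.
def _example_DB_to_amplitude() -> None:
    power_spec = torch.tensor([4.0, 1.0, 0.25])
    db = amplitude_to_DB(power_spec, multiplier=10., amin=1e-10,
                         db_multiplier=0.0, top_db=None)
    assert torch.allclose(DB_to_amplitude(db, ref=1.0, power=1.0), power_spec)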


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequency in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted to Hz
    """

    if mel_scale not in ['slaney', 'htk']:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0**(mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = (mels >= min_log_mel)
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs
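

# Worked check (illustrative): on the HTK scale 1000 Hz maps to roughly
# 1000 mels, since 2595 * log10(1 + 1000 / 700) is about 1000.0.
def _example_mel_scale() -> None:
    mels = _hz_to_mel(1000.0, mel_scale="htk")
    assert abs(mels - 1000.0) < 0.1
    freqs = _mel_to_hz(torch.tensor([mels]), mel_scale="htk")
    assert torch.allclose(freqs, torch.tensor([1000.0]), atol=0.01)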


def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        norm: Optional[str] = None,
        mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width
            of the mel band (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``),
        meaning the number of frequencies to highlight/apply times the number of filterbanks.
        Each column is a filterbank so that, assuming a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * create_fb_matrix(A.size(-1), ...)``.
    """
    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
    m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)

    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
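

# Application sketch (illustrative): turn a linear-frequency spectrogram into
# a mel spectrogram by matrix-multiplying with the filterbank, as the
# docstring describes. Shapes below are assumptions.
def _example_create_fb_matrix() -> None:
    n_fft, sample_rate = 400, 16000
    fb = create_fb_matrix(n_freqs=n_fft // 2 + 1, f_min=0., f_max=8000.,
                          n_mels=40, sample_rate=sample_rate)
    specgram = torch.rand(1, n_fft // 2 + 1, 100)        # (channel, freq, time)
    mel_specgram = torch.matmul(specgram.transpose(-1, -2), fb).transpose(-1, -2)
    assert mel_specgram.shape == (1, 40, 100)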


def create_dct(
        n_mfcc: int,
        n_mels: int,
        norm: Optional[str]
) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()
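

# Usage sketch (illustrative): project a mel spectrogram onto the DCT basis to
# get MFCC-like coefficients; the matrix right-multiplies time-major data.
def _example_create_dct() -> None:
    dct_mat = create_dct(n_mfcc=13, n_mels=40, norm="ortho")   # (40, 13)
    mel_specgram = torch.rand(1, 40, 100)                      # (channel, n_mels, time)
    mfcc = torch.matmul(mel_specgram.transpose(-1, -2), dct_mat).transpose(-1, -2)
    assert mfcc.shape == (1, 13, 100)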


def mu_law_encoding(
        x: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Encode signal based on mu-law companding.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law encoding
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(
        x_mu: Tensor,
        quantization_channels: int
) -> Tensor:
    r"""Decode mu-law encoded signal.  For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = ((x_mu) / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x
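

# Round-trip sketch (illustrative): mu-law encode to 256 levels and decode
# back; the reconstruction error is bounded by the quantization step.
def _example_mu_law_round_trip() -> None:
    x = torch.linspace(-1., 1., steps=1000)
    x_mu = mu_law_encoding(x, quantization_channels=256)
    x_hat = mu_law_decoding(x_mu, quantization_channels=256)
    assert x_mu.dtype == torch.int64
    assert (x - x_hat).abs().max() < 0.1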


def complex_norm(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tensor:
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        Tensor: Power of the normed input tensor. Shape of `(..., )`
    """
    # Replace by torch.norm once issue is fixed
    # https://github.com/pytorch/pytorch/issues/34279
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)


def angle(
        complex_tensor: Tensor
) -> Tensor:
    r"""Compute the angle of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`

    Return:
        Tensor: Angle of a complex tensor. Shape of `(..., )`
    """
    return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])


def magphase(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tuple[Tensor, Tensor]:
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        (Tensor, Tensor): The magnitude and phase of the complex tensor
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase


def phase_vocoder(
        complex_specgrams: Tensor,
        rate: float,
        phase_advance: Tensor
) -> Tensor:
    r"""Given an STFT tensor, speed up in time without modifying pitch by a
    factor of ``rate``.

    Args:
        complex_specgrams (Tensor):
            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)

    Returns:
        Tensor:
            Stretched spectrogram. The resulting tensor is of the same dtype as the input
            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

    Example - With Tensor of complex dtype
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time)
        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231])

    Example - With Tensor of real dtype and extra dimension for complex field
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time, complex=2)
        >>> complex_specgrams = torch.randn(2, freq, 300, 2)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231, 2])
    """
    if rate == 1.0:
        return complex_specgrams

    if not complex_specgrams.is_complex() and complex_specgrams.size(-1) != 2:
        raise ValueError(
            "complex_specgrams must be either native complex tensors or "
            "real valued tensors with shape (..., 2)")

    is_complex = complex_specgrams.is_complex()

    if not is_complex:
        complex_specgrams = torch.view_as_complex(complex_specgrams)

    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))

    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
    # Note torch.real is a view so it does not incur any memory copy.
    real_dtype = torch.real(complex_specgrams).dtype
    time_steps = torch.arange(
        0,
        complex_specgrams.size(-1),
        rate,
        device=complex_specgrams.device,
        dtype=real_dtype)

    alphas = time_steps % 1.0
    phase_0 = complex_specgrams[..., :1].angle()

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

    angle_0 = complex_specgrams_0.angle()
    angle_1 = complex_specgrams_1.angle()

    norm_0 = complex_specgrams_0.abs()
    norm_1 = complex_specgrams_1.abs()

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    complex_specgrams_stretch = torch.polar(mag, phase_acc)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

    if not is_complex:
        return torch.view_as_real(complex_specgrams_stretch)
    return complex_specgrams_stretch


def mask_along_axis_iid(
        specgrams: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

    Args:
        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

    Returns:
        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
    """

    if axis != 2 and axis != 3:
        raise ValueError('Only Frequency and Time masking are supported')

    device = specgrams.device
    dtype = specgrams.dtype

    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value[..., None, None]
    mask_end = (min_value + value)[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(
        specgram: Tensor,
        mask_param: int,
        mask_value: float,
        axis: int
) -> Tensor:
    r"""
    Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram (channel, freq, time)
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

    Returns:
        Tensor: Masked spectrogram of dimensions (channel, freq, time)
    """

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()

    assert mask_end - mask_start < mask_param
    if axis == 1:
        specgram[:, mask_start:mask_end] = mask_value
    elif axis == 2:
        specgram[:, :, mask_start:mask_end] = mask_value
    else:
        raise ValueError('Only Frequency and Time masking are supported')

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram
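

# SpecAugment-style sketch (illustrative): mask up to 10 consecutive time
# columns, then up to 8 frequency rows, of a single spectrogram.
def _example_mask_along_axis() -> None:
    specgram = torch.rand(1, 201, 100)                   # (channel, freq, time)
    masked = mask_along_axis(specgram, mask_param=10, mask_value=0., axis=2)
    masked = mask_along_axis(masked, mask_param=8, mask_value=0., axis=1)
    assert masked.shape == specgram.shape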


def compute_deltas(
        specgram: Tensor,
        win_length: int = 5,
        mode: str = "replicate"
) -> Tensor:
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. math::
       d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is ``(win_length-1)//2``.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

    Returns:
        Tensor: Tensor of deltas of dimension (..., freq, time)

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """
    device = specgram.device
    dtype = specgram.dtype

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def _compute_nccf(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float,
        freq_low: int
) -> Tensor:
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = int(math.ceil(sample_rate / freq_low))

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = int(math.ceil(waveform_length / frame_size))

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(
        a: Tuple[Tensor, Tensor],
        b: Tuple[Tensor, Tensor],
        thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = (a[0] > thresh * b[0])
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(
        nccf: Tensor,
        sample_rate: int,
        freq_high: int
) -> Tensor:
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the max among all the lags is very close
    to the first half of lags, then the latter is taken.
    """

    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(
        indices: Tensor,
        win_length: int
) -> Tensor:
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(
        indices, (pad_length, 0), mode="constant", value=0.
    )

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
        waveform: Tensor,
        sample_rate: int,
        frame_time: float = 10 ** (-2),
        win_length: int = 30,
        freq_low: int = 85,
        freq_high: int = 3400,
) -> Tensor:
    r"""Detect pitch frequency.

    It is implemented using normalized cross-correlation function and median smoothing.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
        freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

    Returns:
        Tensor: Tensor of freq of dimension (..., frame)
    """
    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq
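

# Usage sketch (illustrative): the detected pitch of a synthetic 440 Hz tone
# lands near 440 Hz, up to the integer-lag resolution of the NCCF search.
def _example_detect_pitch_frequency() -> None:
    sample_rate = 16000
    t = torch.arange(0., 1., 1. / sample_rate)
    waveform = torch.sin(2 * math.pi * 440. * t).unsqueeze(0)   # (1, time)
    pitch = detect_pitch_frequency(waveform, sample_rate)
    assert (pitch.mean() - 440.).abs() < 20.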


def sliding_window_cmn(
    specgram: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Tensor matching input shape (..., time, freq)
    """
    input_shape = specgram.shape
    num_frames, num_feats = input_shape[-2:]
    specgram = specgram.view(-1, num_frames, num_feats)
    num_channels = specgram.shape[0]

    dtype = specgram.dtype
    device = specgram.device
    last_window_start = last_window_end = -1
    cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cmn_specgram = torch.zeros(
        num_channels, num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            window_start -= (window_end - num_frames)
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            input_part = specgram[:, window_start: window_end - window_start, :]
            cur_sum += torch.sum(input_part, 1)
            if norm_vars:
                cur_sumsq += torch.cumsum(input_part ** 2, 1)[:, -1, :]
        else:
            if window_start > last_window_start:
                frame_to_remove = specgram[:, last_window_start, :]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= (frame_to_remove ** 2)
            if window_end > last_window_end:
                frame_to_add = specgram[:, last_window_end, :]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += (frame_to_add ** 2)
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                cmn_specgram[:, t, :] = torch.zeros(
                    num_channels, num_feats, dtype=dtype, device=device)
            else:
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= ((cur_sum ** 2) / (window_frames ** 2))
                variance = torch.pow(variance, -0.5)
                cmn_specgram[:, t, :] *= variance

    cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
    if len(input_shape) == 2:
        cmn_specgram = cmn_specgram.squeeze(0)
    return cmn_specgram


def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time)
    """
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2,
                           device=specgram.device).reshape((-1, 1))
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
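

# Sanity sketch (illustrative): the spectral centroid of a pure 1 kHz tone
# stays near 1 kHz, since almost all magnitude concentrates in that bin.
def _example_spectral_centroid() -> None:
    sample_rate, n_fft = 16000, 400
    t = torch.arange(0., 1., 1. / sample_rate)
    waveform = torch.sin(2 * math.pi * 1000. * t).unsqueeze(0)
    centroid = spectral_centroid(waveform, sample_rate, pad=0,
                                 window=torch.hann_window(n_fft), n_fft=n_fft,
                                 hop_length=200, win_length=n_fft)
    assert (centroid.mean() - 1000.).abs() < 100.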


@_mod_utils.requires_sox()
def apply_codec(
    waveform: Tensor,
    sample_rate: int,
    format: str,
    channels_first: bool = True,
    compression: Optional[float] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool):
            When True, both the input and output Tensor have dimension ``[channel, time]``.
            Otherwise, they have dimension ``[time, channel]``.
        compression (float): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        torch.Tensor: Resulting Tensor.
        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
    """
    bytes = io.BytesIO()
    torchaudio.backend.sox_io_backend.save(bytes,
                                           waveform,
                                           sample_rate,
                                           channels_first,
                                           compression,
                                           format,
                                           encoding,
                                           bits_per_sample
                                           )
    bytes.seek(0)
    augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
        bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format)
    return augmented


@_mod_utils.requires_kaldi()
def compute_kaldi_pitch(
        waveform: torch.Tensor,
        sample_rate: float,
        frame_length: float = 25.0,
        frame_shift: float = 10.0,
        min_f0: float = 50,
        max_f0: float = 400,
        soft_min_f0: float = 10.0,
        penalty_factor: float = 0.1,
        lowpass_cutoff: float = 1000,
        resample_frequency: float = 4000,
        delta_pitch: float = 0.005,
        nccf_ballast: float = 7000,
        lowpass_filter_width: int = 1,
        upsample_filter_width: int = 5,
        max_frames_latency: int = 0,
        frames_per_chunk: int = 0,
        simulate_first_pass_online: bool = False,
        recompute_frame: int = 500,
        snip_edges: bool = True,
) -> torch.Tensor:
    """Extract pitch based on method described in [1].

    This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.

    Args:
        waveform (Tensor):
            The input waveform of shape `(..., time)`.
        sample_rate (float):
            Sample rate of `waveform`.
        frame_length (float, optional):
            Frame length in milliseconds. (default: 25.0)
        frame_shift (float, optional):
            Frame shift in milliseconds. (default: 10.0)
        min_f0 (float, optional):
            Minimum F0 to search for (Hz)  (default: 50.0)
        max_f0 (float, optional):
            Maximum F0 to search for (Hz)  (default: 400.0)
        soft_min_f0 (float, optional):
            Minimum f0, applied in soft way, must not exceed min-f0  (default: 10.0)
        penalty_factor (float, optional):
            Cost factor for F0 change.  (default: 0.1)
        lowpass_cutoff (float, optional):
            Cutoff frequency for LowPass filter (Hz) (default: 1000)
        resample_frequency (float, optional):
            Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
            (default: 4000)
        delta_pitch (float, optional):
            Smallest relative change in pitch that our algorithm measures. (default: 0.005)
        nccf_ballast (float, optional):
            Increasing this factor reduces NCCF for quiet frames (default: 7000)
        lowpass_filter_width (int, optional):
            Integer that determines filter width of lowpass filter, more gives sharper filter.
            (default: 1)
        upsample_filter_width (int, optional):
            Integer that determines filter width when upsampling NCCF. (default: 5)
        max_frames_latency (int, optional):
            Maximum number of frames of latency that we allow pitch tracking to introduce into
            the feature processing (affects output only if ``frames_per_chunk > 0`` and
            ``simulate_first_pass_online=True``) (default: 0)
        frames_per_chunk (int, optional):
            The number of frames used for energy normalization. (default: 0)
        simulate_first_pass_online (bool, optional):
            If true, the function will output features that correspond to what an online decoder
            would see in the first pass of decoding -- not the final version of the features,
            which is the default. (default: False)
            Relevant if ``frames_per_chunk > 0``.
        recompute_frame (int, optional):
            Only relevant for compatibility with online pitch extraction.
            A non-critical parameter; the frame at which we recompute some of the forward pointers,
            after revising our estimate of the signal energy.
            Relevant if ``frames_per_chunk > 0``. (default: 500)
        snip_edges (bool, optional):
            If this is set to false, the incomplete frames near the ending edge won't be snipped,
            so that the number of frames is the file size divided by the frame-shift.
            This makes different types of features give the same number of frames. (default: True)

    Returns:
       Tensor: Pitch feature. Shape: ``(batch, frames, 2)`` where the last dimension
       corresponds to pitch and NCCF.

    Reference:
        - A pitch extraction algorithm tuned for automatic speech recognition

          P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur

          2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),

          Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049.
    """
    shape = waveform.shape
    waveform = waveform.reshape(-1, shape[-1])
    result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
        waveform, sample_rate, frame_length, frame_shift,
        min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff,
        resample_frequency, delta_pitch, nccf_ballast,
        lowpass_filter_width, upsample_filter_width, max_frames_latency,
        frames_per_chunk, simulate_first_pass_online, recompute_frame,
        snip_edges,
    )
    result = result.reshape(shape[:-1] + result.shape[-2:])
    return result
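

# Call sketch (illustrative; assumes torchaudio was built with the Kaldi
# extension, otherwise the decorated function raises): the last dimension
# stacks the NCCF and pitch values for each frame.
def _example_compute_kaldi_pitch() -> None:
    waveform = torch.randn(1, 16000)
    pitch_feature = compute_kaldi_pitch(waveform, sample_rate=16000.)
    nccf, pitch = pitch_feature[..., 0], pitch_feature[..., 1]
    assert nccf.shape == pitch.shape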


def _get_sinc_resample_kernel(orig_freq: int, new_freq: int, lowpass_filter_width: int,
                              device: torch.device, dtype: torch.dtype):
    assert lowpass_filter_width > 0
    kernels = []
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= 0.99

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #    y[j] = x(j / new_freq)
    # or,
    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This will explain the F.conv1d after, with a stride of orig_freq.
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)

    for i in range(new_freq):
        t = (-i / new_freq + idx / orig_freq) * base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
        t *= math.pi
        # we do not use torch.hann_window here as we need to evaluate the window
        # at specific positions, not over a regular grid.
        window = torch.cos(t / lowpass_filter_width / 2)**2
        kernel = torch.where(t == 0, torch.tensor(1.).to(t), torch.sin(t) / t)
        kernel.mul_(window)
        kernels.append(kernel)

    scale = base_freq / orig_freq
    return torch.stack(kernels).view(new_freq, 1, -1).mul_(scale), width


def resample(
        waveform: Tensor,
        orig_freq: float,
        new_freq: float,
        lowpass_filter_width: int = 6
) -> Tensor:
    r"""Resamples the waveform at the new frequency. This matches Kaldi's OfflineFeatureTpl ResampleWaveform
    which uses a LinearResample (resample a signal at linearly spaced intervals to upsample/downsample
    a signal). LinearResample (LR) means that the output signal is at linearly spaced intervals (i.e
    the output signal has a frequency of ``new_freq``). It uses sinc/bandlimited interpolation to
    upsample/downsample the signal.

    https://ccrma.stanford.edu/~jos/resample/Theory_Ideal_Bandlimited_Interpolation.html
    https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56

    Args:
        waveform (Tensor): The input signal of dimension (..., time)
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. We suggest around 4 to 10 for normal use. (Default: ``6``)

    Returns:
        Tensor: The waveform at the new frequency of dimension (..., time).
    """
    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    assert orig_freq > 0.0 and new_freq > 0.0

    orig_freq = int(orig_freq)
    new_freq = int(new_freq)
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq = orig_freq // gcd
    new_freq = new_freq // gcd

    kernel, width = _get_sinc_resample_kernel(orig_freq, new_freq, lowpass_filter_width,
                                              waveform.device, waveform.dtype)

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch
    resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
    return resampled
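

# Usage sketch (illustrative): downsample a hypothetical 16 kHz clip to 8 kHz;
# the output length scales by new_freq / orig_freq.
def _example_resample() -> None:
    waveform = torch.randn(1, 16000)
    resampled = resample(waveform, orig_freq=16000, new_freq=8000,
                         lowpass_filter_width=6)
    assert resampled.shape == (1, 8000)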