# -*- coding: utf-8 -*-

import io
import math
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple, Union

import torch
import torchaudio
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils

__all__ = [
    "spectrogram",
    "inverse_spectrogram",
    "griffinlim",
    "amplitude_to_DB",
    "DB_to_amplitude",
    "compute_deltas",
    "compute_kaldi_pitch",
    "melscale_fbanks",
    "linear_fbanks",
    "create_dct",
    "detect_pitch_frequency",
    "mu_law_encoding",
    "mu_law_decoding",
    "phase_vocoder",
    "mask_along_axis",
    "mask_along_axis_iid",
    "sliding_window_cmn",
    "spectral_centroid",
    "apply_codec",
    "resample",
    "edit_distance",
    "pitch_shift",
    "rnnt_loss",
    "psd",
    "mvdr_weights_souden",
    "mvdr_weights_rtf",
    "rtf_evd",
    "rtf_power",
    "apply_beamforming",
]


def spectrogram(
    waveform: Tensor,
    pad: int,
    window: Tensor,
    n_fft: int,
    hop_length: int,
    win_length: int,
    power: Optional[float],
    normalized: bool,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
    return_complex: Optional[bool] = None,
) -> Tensor:
    r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
    The spectrogram can be either magnitude-only or complex.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        power (float or None): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
            If None, then the complex spectrum is returned instead.
        normalized (bool): Whether to normalize by magnitude after stft
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``
        return_complex (bool, optional):
            Deprecated and not used.

    Returns:
        Tensor: Dimension `(..., freq, time)`, freq is
        ``n_fft // 2 + 1`` and ``n_fft`` is the number of
        Fourier bins, and time is the number of window hops (n_frame).
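
    Example
        An illustrative sketch; the Hann window and sizes below are arbitrary choices:

        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> waveform = torch.randn(1, 16000)
        >>> spec = spectrogram(waveform, pad=0, window=window, n_fft=n_fft,
        ...                    hop_length=200, win_length=n_fft, power=2.0, normalized=False)
        >>> spec.shape
        torch.Size([1, 201, 81])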
    """
    if return_complex is not None:
        warnings.warn(
            "`return_complex` argument is now deprecated and is not effective. "
            "`torchaudio.functional.spectrogram(power=None)` always returns a tensor with "
            "complex dtype. Please remove the argument in the function call."
        )

    if pad > 0:
        # TODO add "with torch.no_grad():" back when JIT supports it
        waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        pad_mode=pad_mode,
        normalized=False,
        onesided=onesided,
        return_complex=True,
    )

    # unpack batch
    spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-2:])

    if normalized:
        spec_f /= window.pow(2.0).sum().sqrt()
    if power is not None:
        if power == 1.0:
            return spec_f.abs()
        return spec_f.abs().pow(power)
    return spec_f


def inverse_spectrogram(
    spectrogram: Tensor,
    length: Optional[int],
    pad: int,
    window: Tensor,
    n_fft: int,
    hop_length: int,
    win_length: int,
    normalized: bool,
    center: bool = True,
    pad_mode: str = "reflect",
    onesided: bool = True,
) -> Tensor:
    r"""Create an inverse spectrogram or a batch of inverse spectrograms from the provided
    complex-valued spectrogram.

moto's avatar
moto committed
153
154
155
156
    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

157
158
    Args:
        spectrogram (Tensor): Complex tensor of audio of dimension (..., freq, time).
159
        length (int or None): The output length of the waveform.
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
        pad (int): Two sided padding of signal. It is only effective when ``length`` is provided.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size
        normalized (bool): Whether the stft output was normalized by magnitude
        center (bool, optional): whether the waveform was padded on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. This parameter is provided for compatibility with the
            spectrogram function and is not used. Default: ``"reflect"``
        onesided (bool, optional): controls whether spectrogram was done in onesided mode.
            Default: ``True``

    Returns:
176
        Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
177
178
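
    Example
        An illustrative round trip; the window and sizes are arbitrary choices:

        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> waveform = torch.randn(1, 16000)
        >>> spec = spectrogram(waveform, pad=0, window=window, n_fft=n_fft,
        ...                    hop_length=200, win_length=n_fft, power=None, normalized=False)
        >>> recovered = inverse_spectrogram(spec, length=16000, pad=0, window=window,
        ...                                 n_fft=n_fft, hop_length=200, win_length=n_fft,
        ...                                 normalized=False)
        >>> recovered.shape
        torch.Size([1, 16000])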
    """

    if not spectrogram.is_complex():
        raise ValueError("Expected `spectrogram` to be complex dtype.")

    if normalized:
        spectrogram = spectrogram * window.pow(2.0).sum().sqrt()

    # pack batch
    shape = spectrogram.size()
    spectrogram = spectrogram.reshape(-1, shape[-2], shape[-1])

    # default values are consistent with librosa.core.spectrum._spectrogram
    waveform = torch.istft(
        input=spectrogram,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=center,
        normalized=False,
        onesided=onesided,
        length=length + 2 * pad if length is not None else None,
        return_complex=False,
    )

    if length is not None and pad > 0:
        # remove padding from front and back
        waveform = waveform[:, pad:-pad]

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform


def _get_complex_dtype(real_dtype: torch.dtype):
    if real_dtype == torch.double:
        return torch.cdouble
    if real_dtype == torch.float:
        return torch.cfloat
    if real_dtype == torch.half:
        return torch.complex32
    raise ValueError(f"Unexpected dtype {real_dtype}")


def griffinlim(
    specgram: Tensor,
    window: Tensor,
    n_fft: int,
    hop_length: int,
    win_length: int,
    power: float,
    n_iter: int,
    momentum: float,
    length: Optional[int],
    rand_init: bool,
) -> Tensor:
    r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Implementation ported from
    *librosa* [:footcite:`brian_mcfee-proc-scipy-2015`], *A fast Griffin-Lim algorithm* [:footcite:`6701851`]
    and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
        hop_length (int): Length of hop between STFT windows. (Default: ``win_length // 2``)
        win_length (int): Window size. (Default: ``n_fft``)
        power (float): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc.
        n_iter (int): Number of iterations for the phase recovery process.
        momentum (float): The momentum parameter for fast Griffin-Lim.
            Setting this to 0 recovers the original Griffin-Lim method.
            Values near 1 can lead to faster convergence, but above 1 may not converge.
        length (int or None): Array length of the expected output.
        rand_init (bool): Initializes phase randomly if True, and to zero otherwise.

    Returns:
        Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
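
    Example
        An illustrative sketch; the window, sizes and iteration count are arbitrary choices:

        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> specgram = torch.rand(1, 201, 100)  # magnitude-only spectrogram
        >>> waveform = griffinlim(specgram, window, n_fft=n_fft, hop_length=200,
        ...                       win_length=n_fft, power=2.0, n_iter=32,
        ...                       momentum=0.99, length=None, rand_init=True)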
    """
    assert momentum < 1, "momentum={} >= 1 can be unstable".format(momentum)
    assert momentum >= 0, "momentum={} < 0".format(momentum)

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))

    specgram = specgram.pow(1 / power)

    # initialize the phase
    if rand_init:
        angles = torch.rand(specgram.size(), dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)
    else:
        angles = torch.full(specgram.size(), 1, dtype=_get_complex_dtype(specgram.dtype), device=specgram.device)

    # And initialize the previous iterate to 0
    tprev = torch.tensor(0.0, dtype=specgram.dtype, device=specgram.device)
    for _ in range(n_iter):
        # Invert with our current estimate of the phases
        inverse = torch.istft(
            specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
        )

        # Rebuild the spectrogram
        rebuilt = torch.stft(
            input=inverse,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=True,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        # Update our phase estimates
        angles = rebuilt
        if momentum:
            angles = angles - tprev.mul_(momentum / (1 + momentum))
        angles = angles.div(angles.abs().add(1e-16))

        # Store the previous iterate
        tprev = rebuilt

    # Return the final phase estimates
    waveform = torch.istft(
        specgram * angles, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=length
    )

    # unpack batch
    waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])

    return waveform


def amplitude_to_DB(
    x: Tensor, multiplier: float, amin: float, db_multiplier: float, top_db: Optional[float] = None
) -> Tensor:
    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    The output of each tensor in a batch depends on the maximum value of that tensor,
    and so may return different values for an audio clip split into snippets vs. a full clip.

    Args:
        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
          the form `(..., freq, time)`. Batched inputs should include a channel dimension and
          have the form `(batch, channel, freq, time)`.
        multiplier (float): Use 10. for power and 20. for amplitude
        amin (float): Number to clamp ``x``
        db_multiplier (float): Log10(max(reference value and amin))
        top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Returns:
        Tensor: Output tensor in decibel scale
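
    Example
        An illustrative sketch for a power spectrogram; the values are arbitrary:

        >>> power_spec = torch.rand(1, 201, 100)
        >>> db = amplitude_to_DB(power_spec, multiplier=10.0, amin=1e-10,
        ...                      db_multiplier=0.0, top_db=80.0)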
    """
    x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
    x_db -= multiplier * db_multiplier

    if top_db is not None:
        # Expand batch
        shape = x_db.size()
        packed_channels = shape[-3] if x_db.dim() > 2 else 1
        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])

        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))

        # Repack batch
        x_db = x_db.reshape(shape)

    return x_db


def DB_to_amplitude(x: Tensor, ref: float, power: float) -> Tensor:
    r"""Turn a tensor from the decibel scale to the power/amplitude scale.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        x (Tensor): Input tensor before being converted to power/amplitude scale.
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Returns:
        Tensor: Output tensor in power/amplitude scale.
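
    Example
        An illustrative round trip with :py:func:`amplitude_to_DB`; the values are arbitrary:

        >>> power_spec = torch.rand(1, 201, 100)
        >>> db = amplitude_to_DB(power_spec, multiplier=10.0, amin=1e-10, db_multiplier=0.0)
        >>> recovered = DB_to_amplitude(db, ref=1.0, power=1.0)
        >>> torch.allclose(recovered, power_spec, atol=1e-5)
        True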
    """
    return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)


def _hz_to_mel(freq: float, mel_scale: str = "htk") -> float:
    r"""Convert Hz to Mels.

    Args:
        freq (float): Frequency in Hz
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        mels (float): Frequency in Mels
    """

    if mel_scale not in ["slaney", "htk"]:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 2595.0 * math.log10(1.0 + (freq / 700.0))

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3

    mels = (freq - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    if freq >= min_log_hz:
        mels = min_log_mel + math.log(freq / min_log_hz) / logstep

    return mels


def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
    """Convert mel bin numbers to frequencies.

    Args:
        mels (Tensor): Mel frequencies
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        freqs (Tensor): Mels converted in Hz
    """

    if mel_scale not in ["slaney", "htk"]:
        raise ValueError('mel_scale should be one of "htk" or "slaney".')

    if mel_scale == "htk":
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    log_t = mels >= min_log_mel
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

    return freqs


def _create_triangular_filterbank(
    all_freqs: Tensor,
    f_pts: Tensor,
) -> Tensor:
    """Create a triangular filter bank.

    Args:
        all_freqs (Tensor): STFT freq points of size (`n_freqs`).
        f_pts (Tensor): Filter mid points of size (`n_filter`).

    Returns:
        fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
    """
    # Adapted from Librosa
    # calculate the difference between each filter mid point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_filter)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_filter)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    return fb


def melscale_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> Tensor:
    r"""Create a frequency bin conversion matrix.

    .. devices:: CPU

    .. properties:: TorchScript

    Note:
        For the sake of the numerical compatibility with librosa, not all the coefficients
        in the resulting filter bank have magnitude of 1.

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/mel_fbanks.png
           :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_mels (int): Number of mel filterbanks
        sample_rate (int): Sample rate of the audio waveform
        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
            (area normalization). (Default: ``None``)
        mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * melscale_fbanks(A.size(-1), ...)``.
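
    Example
        Applying the filter bank to a spectrogram; an illustrative sketch with arbitrary sizes:

        >>> fb = melscale_fbanks(n_freqs=201, f_min=0.0, f_max=8000.0, n_mels=40, sample_rate=16000)
        >>> specgram = torch.rand(1, 201, 100)  # (channel, n_freqs, time)
        >>> mel_specgram = torch.matmul(specgram.transpose(-1, -2), fb).transpose(-1, -2)
        >>> mel_specgram.shape
        torch.Size([1, 40, 100])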
    """

    if norm is not None and norm != "slaney":
        raise ValueError("norm must be one of None or 'slaney'")

    # freq bins
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    m_min = _hz_to_mel(f_min, mel_scale=mel_scale)
    m_max = _hz_to_mel(f_max, mel_scale=mel_scale)

    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    if norm is not None and norm == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)

    if (fb.max(dim=0).values == 0.0).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb


def linear_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_filter: int,
    sample_rate: int,
) -> Tensor:
    r"""Creates a linear triangular filterbank.

    .. devices:: CPU

    .. properties:: TorchScript

    Note:
        For the sake of the numerical compatibility with librosa, not all the coefficients
        in the resulting filter bank have magnitude of 1.

        .. image:: https://download.pytorch.org/torchaudio/doc-assets/lin_fbanks.png
           :alt: Visualization of generated filter bank

    Args:
        n_freqs (int): Number of frequencies to highlight/apply
        f_min (float): Minimum frequency (Hz)
        f_max (float): Maximum frequency (Hz)
        n_filter (int): Number of (linear) triangular filters
        sample_rate (int): Sample rate of the audio waveform

    Returns:
        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_filter``)
        meaning number of frequencies to highlight/apply to x the number of filterbanks.
        Each column is a filterbank so that assuming there is a matrix A of
        size (..., ``n_freqs``), the applied result would be
        ``A * linear_fbanks(A.size(-1), ...)``.
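
    Example
        An illustrative sketch with arbitrary sizes:

        >>> fb = linear_fbanks(n_freqs=201, f_min=0.0, f_max=8000.0, n_filter=20, sample_rate=16000)
        >>> fb.shape
        torch.Size([201, 20])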
    """
    # freq bins
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # filter mid-points
    f_pts = torch.linspace(f_min, f_max, n_filter + 2)

    # create filterbank
    fb = _create_triangular_filterbank(all_freqs, f_pts)

    return fb


def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
    r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
    normalized depending on norm.

    .. devices:: CPU

    .. properties:: TorchScript

    Args:
        n_mfcc (int): Number of mfc coefficients to retain
        n_mels (int): Number of mel filterbanks
        norm (str or None): Norm to use (either 'ortho' or None)

    Returns:
        Tensor: The transformation matrix, to be right-multiplied to
        row-wise data of size (``n_mels``, ``n_mfcc``).
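
    Example
        Applying the DCT matrix to a mel spectrogram; an illustrative sketch with arbitrary sizes:

        >>> dct_mat = create_dct(n_mfcc=13, n_mels=40, norm="ortho")
        >>> mel_specgram = torch.rand(1, 40, 100)  # (channel, n_mels, time)
        >>> mfcc = torch.matmul(mel_specgram.transpose(-1, -2), dct_mat).transpose(-1, -2)
        >>> mfcc.shape
        torch.Size([1, 13, 100])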
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = torch.arange(float(n_mels))
    k = torch.arange(float(n_mfcc)).unsqueeze(1)
    dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k)  # size (n_mfcc, n_mels)
    if norm is None:
        dct *= 2.0
    else:
        assert norm == "ortho"
        dct[0] *= 1.0 / math.sqrt(2.0)
        dct *= math.sqrt(2.0 / float(n_mels))
    return dct.t()


def mu_law_encoding(x: Tensor, quantization_channels: int) -> Tensor:
    r"""Encode signal based on mu-law companding.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm expects the signal to have been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1.

    Args:
        x (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law encoding
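
    Example
        An illustrative sketch with an arbitrary waveform:

        >>> waveform = torch.linspace(-1.0, 1.0, steps=5)
        >>> x_mu = mu_law_encoding(waveform, quantization_channels=256)
        >>> int(x_mu.min()), int(x_mu.max())
        (0, 255)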
    """
    mu = quantization_channels - 1.0
    if not x.is_floating_point():
        warnings.warn(
            "The input Tensor must be of floating type. \
            This will be an error in the v0.12 release."
        )
        x = x.to(torch.float)
    mu = torch.tensor(mu, dtype=x.dtype)
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
    x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
    return x_mu


def mu_law_decoding(x_mu: Tensor, quantization_channels: int) -> Tensor:
    r"""Decode mu-law encoded signal.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
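
    Example
        An illustrative round trip with :py:func:`mu_law_encoding`; the tolerance is arbitrary:

        >>> waveform = torch.linspace(-1.0, 1.0, steps=5)
        >>> x_mu = mu_law_encoding(waveform, quantization_channels=256)
        >>> decoded = mu_law_decoding(x_mu, quantization_channels=256)
        >>> torch.allclose(decoded, waveform, atol=1e-2)
        True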
    """
    mu = quantization_channels - 1.0
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(mu, dtype=x_mu.dtype)
    x = ((x_mu) / mu) * 2 - 1.0
    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
    return x


def phase_vocoder(complex_specgrams: Tensor, rate: float, phase_advance: Tensor) -> Tensor:
    r"""Given an STFT tensor, speed up in time without modifying pitch by a factor of ``rate``.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        complex_specgrams (Tensor):
            A tensor of dimension `(..., freq, num_frame)` with complex dtype.
        rate (float): Speed-up factor
        phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`

    Returns:
        Tensor:
            Stretched spectrogram. The resulting tensor is of the same dtype as the input
            spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

    Example
        >>> freq, hop_length = 1025, 512
        >>> # (channel, freq, time)
        >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
        >>> rate = 1.3 # Speed up by 30%
        >>> phase_advance = torch.linspace(
        >>>    0, math.pi * hop_length, freq)[..., None]
        >>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
        >>> x.shape # with 231 == ceil(300 / 1.3)
        torch.Size([2, 1025, 231])
    """
    if rate == 1.0:
        return complex_specgrams

    # pack batch
    shape = complex_specgrams.size()
    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))

    # Figures out the corresponding real dtype, i.e. complex128 -> float64, complex64 -> float32
    # Note torch.real is a view so it does not incur any memory copy.
    real_dtype = torch.real(complex_specgrams).dtype
    time_steps = torch.arange(0, complex_specgrams.size(-1), rate, device=complex_specgrams.device, dtype=real_dtype)

    alphas = time_steps % 1.0
    phase_0 = complex_specgrams[..., :1].angle()

    # Time Padding
    complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 2])

    # (new_bins, freq, 2)
    complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long())
    complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long())

    angle_0 = complex_specgrams_0.angle()
    angle_1 = complex_specgrams_1.angle()

    norm_0 = complex_specgrams_0.abs()
    norm_1 = complex_specgrams_1.abs()

    phase = angle_1 - angle_0 - phase_advance
    phase = phase - 2 * math.pi * torch.round(phase / (2 * math.pi))

    # Compute Phase Accum
    phase = phase + phase_advance
    phase = torch.cat([phase_0, phase[..., :-1]], dim=-1)
    phase_acc = torch.cumsum(phase, -1)

    mag = alphas * norm_1 + (1 - alphas) * norm_0

    complex_specgrams_stretch = torch.polar(mag, phase_acc)

    # unpack batch
    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
    return complex_specgrams_stretch


def _get_mask_param(mask_param: int, p: float, axis_length: int) -> int:
    if p == 1.0:
        return mask_param
    else:
        return min(mask_param, int(axis_length * p))


def mask_along_axis_iid(
    specgrams: Tensor,
    mask_param: int,
    mask_value: float,
    axis: int,
    p: float = 1.0,
) -> Tensor:
    r"""Apply a mask along ``axis``.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Mask will be applied from indices ``[v_0, v_0 + v)``,
    where ``v`` is sampled from ``uniform(0, max_v)`` and
    ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``,
    with ``max_v = mask_param`` when ``p = 1.0`` and
    ``max_v = min(mask_param, floor(specgrams.size(axis) * p))`` otherwise.

    Args:
        specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)

    Returns:
        Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
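
    Example
        Time masking each batch example independently; an illustrative sketch with arbitrary sizes:

        >>> specgrams = torch.randn(4, 2, 201, 100)
        >>> masked = mask_along_axis_iid(specgrams, mask_param=27, mask_value=0.0, axis=3)
        >>> masked.shape
        torch.Size([4, 2, 201, 100])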
    """

    if axis not in [2, 3]:
        raise ValueError("Only Frequency and Time masking are supported")

    if not 0.0 <= p <= 1.0:
        raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")

    mask_param = _get_mask_param(mask_param, p, specgrams.shape[axis])
    if mask_param < 1:
        return specgrams

    device = specgrams.device
    dtype = specgrams.dtype

    value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
    min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)

    # Create broadcastable mask
    mask_start = min_value.long()[..., None, None]
    mask_end = (min_value.long() + value.long())[..., None, None]
    mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)

    # Per batch example masking
    specgrams = specgrams.transpose(axis, -1)
    specgrams = specgrams.masked_fill((mask >= mask_start) & (mask < mask_end), mask_value)
    specgrams = specgrams.transpose(axis, -1)

    return specgrams


def mask_along_axis(
    specgram: Tensor,
    mask_param: int,
    mask_value: float,
    axis: int,
    p: float = 1.0,
) -> Tensor:
    r"""Apply a mask along ``axis``.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Mask will be applied from indices ``[v_0, v_0 + v)``,
    where ``v`` is sampled from ``uniform(0, max_v)`` and
    ``v_0`` from ``uniform(0, specgrams.size(axis) - v)``, with
    ``max_v = mask_param`` when ``p = 1.0`` and
    ``max_v = min(mask_param, floor(specgrams.size(axis) * p))``
    otherwise.
    All examples will have the same mask interval.

    Args:
        specgram (Tensor): Real spectrogram `(channel, freq, time)`
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
        p (float, optional): maximum proportion of columns that can be masked. (Default: 1.0)

    Returns:
        Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
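
    Example
        Frequency masking; an illustrative sketch with arbitrary sizes:

        >>> specgram = torch.randn(2, 201, 100)
        >>> masked = mask_along_axis(specgram, mask_param=27, mask_value=0.0, axis=1)
        >>> masked.shape
        torch.Size([2, 201, 100])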
    """
    if axis not in [1, 2]:
        raise ValueError("Only Frequency and Time masking are supported")

    if not 0.0 <= p <= 1.0:
        raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")

    mask_param = _get_mask_param(mask_param, p, specgram.shape[axis])
    if mask_param < 1:
        return specgram

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape([-1] + list(shape[-2:]))
    value = torch.rand(1) * mask_param
    min_value = torch.rand(1) * (specgram.size(axis) - value)

    mask_start = (min_value.long()).squeeze()
    mask_end = (min_value.long() + value.long()).squeeze()
    mask = torch.arange(0, specgram.shape[axis], device=specgram.device, dtype=specgram.dtype)
    mask = (mask >= mask_start) & (mask < mask_end)
    if axis == 1:
        mask = mask.unsqueeze(-1)

    assert mask_end - mask_start < mask_param

    specgram = specgram.masked_fill(mask, mask_value)

    # unpack batch
    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])

    return specgram


def compute_deltas(specgram: Tensor, win_length: int = 5, mode: str = "replicate") -> Tensor:
    r"""Compute delta coefficients of a tensor, usually a spectrogram:

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    .. math::
       d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N}} n^2}

    where :math:`d_t` is the deltas at time :math:`t`,
    :math:`c_t` is the spectrogram coefficients at time :math:`t`,
    :math:`N` is ``(win_length-1)//2``.

    Args:
        specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

    Returns:
        Tensor: Tensor of deltas of dimension `(..., freq, time)`

    Example
        >>> specgram = torch.randn(1, 40, 1000)
        >>> delta = compute_deltas(specgram)
        >>> delta2 = compute_deltas(delta)
    """
    device = specgram.device
    dtype = specgram.dtype

    # pack batch
    shape = specgram.size()
    specgram = specgram.reshape(1, -1, shape[-1])

    assert win_length >= 3

    n = (win_length - 1) // 2

    # twice sum of integer squared
    denom = n * (n + 1) * (2 * n + 1) / 3

    specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)

    kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)

    output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom

    # unpack batch
    output = output.reshape(shape)

    return output


def _compute_nccf(waveform: Tensor, sample_rate: int, frame_time: float, freq_low: int) -> Tensor:
    r"""
    Compute Normalized Cross-Correlation Function (NCCF).

    .. math::
        \phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},

    where
    :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
    :math:`w` is the waveform,
    :math:`N` is the length of a frame,
    :math:`b_i` is the beginning of frame :math:`i`,
    :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
    """

    EPSILON = 10 ** (-9)

    # Number of lags to check
    lags = int(math.ceil(sample_rate / freq_low))

    frame_size = int(math.ceil(sample_rate * frame_time))

    waveform_length = waveform.size()[-1]
    num_of_frames = int(math.ceil(waveform_length / frame_size))

    p = lags + num_of_frames * frame_size - waveform_length
    waveform = torch.nn.functional.pad(waveform, (0, p))

    # Compute lags
    output_lag = []
    for lag in range(1, lags + 1):
        s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
        s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]

        output_frames = (
            (s1 * s2).sum(-1)
            / (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
            / (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
        )

        output_lag.append(output_frames.unsqueeze(-1))

    nccf = torch.cat(output_lag, -1)

    return nccf


def _combine_max(a: Tuple[Tensor, Tensor], b: Tuple[Tensor, Tensor], thresh: float = 0.99) -> Tuple[Tensor, Tensor]:
    """
    Take value from first if bigger than a multiplicative factor of the second, elementwise.
    """
    mask = a[0] > thresh * b[0]
    values = mask * a[0] + ~mask * b[0]
    indices = mask * a[1] + ~mask * b[1]
    return values, indices


def _find_max_per_frame(nccf: Tensor, sample_rate: int, freq_high: int) -> Tensor:
    r"""
    For each frame, take the highest value of NCCF,
    apply centered median smoothing, and convert to frequency.

    Note: If the max among all the lags is very close
    to the first half of lags, then the latter is taken.
    """

    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest

    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices


def _median_smoothing(indices: Tensor, win_length: int) -> Tensor:
    r"""
    Apply median smoothing to the 1D tensor over the given window.
    """

    # Centered windowed
    pad_length = (win_length - 1) // 2

    # "replicate" padding in any dimension
    indices = torch.nn.functional.pad(indices, (pad_length, 0), mode="constant", value=0.0)

    indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
    roll = indices.unfold(-1, win_length, 1)

    values, _ = torch.median(roll, -1)
    return values


def detect_pitch_frequency(
    waveform: Tensor,
    sample_rate: int,
    frame_time: float = 10 ** (-2),
    win_length: int = 30,
    freq_low: int = 85,
    freq_high: int = 3400,
) -> Tensor:
    r"""Detect pitch frequency.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    It is implemented using normalized cross-correlation function and median smoothing.

    Args:
        waveform (Tensor): Tensor of audio of dimension `(..., freq, time)`
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
        freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

    Returns:
        Tensor: Tensor of freq of dimension `(..., frame)`
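
    Example
        An illustrative sketch with a synthetic 440 Hz tone:

        >>> sample_rate = 16000
        >>> t = torch.arange(sample_rate) / sample_rate
        >>> waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)
        >>> freq = detect_pitch_frequency(waveform, sample_rate)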
    """
    # pack batch
    shape = list(waveform.size())
    waveform = waveform.reshape([-1] + shape[-1:])

    nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    indices = _find_max_per_frame(nccf, sample_rate, freq_high)
    indices = _median_smoothing(indices, win_length)

    # Convert indices to frequency
    EPSILON = 10 ** (-9)
    freq = sample_rate / (EPSILON + indices.to(torch.float))

    # unpack batch
    freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))

    return freq


def sliding_window_cmn(
    specgram: Tensor,
    cmn_window: int = 600,
    min_cmn_window: int = 100,
    center: bool = False,
    norm_vars: bool = False,
) -> Tensor:
    r"""
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        specgram (Tensor): Tensor of spectrogram of dimension `(..., time, freq)`
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional):  Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
        center (bool, optional): If true, use a window centered on the current frame
            (to the extent possible, modulo end effects). If false, window is to the left. (bool, default = false)
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

    Returns:
        Tensor: Tensor matching input shape `(..., time, freq)`
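
    Example
        An illustrative sketch; sizes are arbitrary and the input is `(..., time, freq)`:

        >>> specgram = torch.randn(1, 1000, 40)
        >>> normalized = sliding_window_cmn(specgram, cmn_window=600)
        >>> normalized.shape
        torch.Size([1, 1000, 40])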
    """
    input_shape = specgram.shape
    num_frames, num_feats = input_shape[-2:]
    specgram = specgram.view(-1, num_frames, num_feats)
    num_channels = specgram.shape[0]

    dtype = specgram.dtype
    device = specgram.device
    last_window_start = last_window_end = -1
    cur_sum = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cur_sumsq = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
    cmn_specgram = torch.zeros(num_channels, num_frames, num_feats, dtype=dtype, device=device)
    for t in range(num_frames):
        window_start = 0
        window_end = 0
        if center:
            window_start = t - cmn_window // 2
            window_end = window_start + cmn_window
        else:
            window_start = t - cmn_window
            window_end = t + 1
        if window_start < 0:
            window_end -= window_start
            window_start = 0
        if not center:
            if window_end > t:
                window_end = max(t + 1, min_cmn_window)
        if window_end > num_frames:
            window_start -= window_end - num_frames
            window_end = num_frames
            if window_start < 0:
                window_start = 0
        if last_window_start == -1:
            input_part = specgram[:, window_start : window_end - window_start, :]
            cur_sum += torch.sum(input_part, 1)
            if norm_vars:
                cur_sumsq += torch.cumsum(input_part**2, 1)[:, -1, :]
        else:
            if window_start > last_window_start:
                frame_to_remove = specgram[:, last_window_start, :]
                cur_sum -= frame_to_remove
                if norm_vars:
                    cur_sumsq -= frame_to_remove**2
            if window_end > last_window_end:
                frame_to_add = specgram[:, last_window_end, :]
                cur_sum += frame_to_add
                if norm_vars:
                    cur_sumsq += frame_to_add**2
        window_frames = window_end - window_start
        last_window_start = window_start
        last_window_end = window_end
        cmn_specgram[:, t, :] = specgram[:, t, :] - cur_sum / window_frames
        if norm_vars:
            if window_frames == 1:
                cmn_specgram[:, t, :] = torch.zeros(num_channels, num_feats, dtype=dtype, device=device)
            else:
                variance = cur_sumsq
                variance = variance / window_frames
                variance -= (cur_sum**2) / (window_frames**2)
                variance = torch.pow(variance, -0.5)
                cmn_specgram[:, t, :] *= variance

    cmn_specgram = cmn_specgram.view(input_shape[:-2] + (num_frames, num_feats))
    if len(input_shape) == 2:
        cmn_specgram = cmn_specgram.squeeze(0)
    return cmn_specgram


def spectral_centroid(
    waveform: Tensor,
    sample_rate: int,
    pad: int,
    window: Tensor,
    n_fft: int,
    hop_length: int,
    win_length: int,
) -> Tensor:
    r"""Compute the spectral centroid for each channel along the time axis.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension `(..., time)`
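
    Example
        An illustrative sketch; the window and sizes are arbitrary choices:

        >>> n_fft = 400
        >>> window = torch.hann_window(n_fft)
        >>> waveform = torch.randn(1, 16000)
        >>> centroid = spectral_centroid(waveform, 16000, pad=0, window=window,
        ...                              n_fft=n_fft, hop_length=200, win_length=n_fft)
        >>> centroid.shape
        torch.Size([1, 81])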
    """
    specgram = spectrogram(
        waveform,
        pad=pad,
        window=window,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=1.0,
        normalized=False,
    )
    freqs = torch.linspace(0, sample_rate // 2, steps=1 + n_fft // 2, device=specgram.device).reshape((-1, 1))
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)


@_mod_utils.requires_sox()
def apply_codec(
    waveform: Tensor,
    sample_rate: int,
    format: str,
    channels_first: bool = True,
    compression: Optional[float] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    .. devices:: CPU

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ```channels_first```.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool, optional):
            When True, both the input and output Tensor have dimension `(channel, time)`.
            Otherwise, they have dimension `(time, channel)`.
        compression (float or None, optional): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str or None, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int or None, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        Tensor: Resulting Tensor.
        If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
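
    Example
        An illustrative sketch; assumes the sox backend and MP3 support are available:

        >>> waveform = torch.rand(1, 16000) * 2 - 1
        >>> augmented = apply_codec(waveform, 16000, format="mp3")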
    """
    bytes = io.BytesIO()
    torchaudio.backend.sox_io_backend.save(
        bytes, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
    )
    bytes.seek(0)
    augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
        bytes, effects=[["rate", f"{sample_rate}"]], channels_first=channels_first, format=format
    )
    return augmented


@_mod_utils.requires_kaldi()
def compute_kaldi_pitch(
    waveform: torch.Tensor,
    sample_rate: float,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_f0: float = 50,
    max_f0: float = 400,
    soft_min_f0: float = 10.0,
    penalty_factor: float = 0.1,
    lowpass_cutoff: float = 1000,
    resample_frequency: float = 4000,
    delta_pitch: float = 0.005,
    nccf_ballast: float = 7000,
    lowpass_filter_width: int = 1,
    upsample_filter_width: int = 5,
    max_frames_latency: int = 0,
    frames_per_chunk: int = 0,
    simulate_first_pass_online: bool = False,
    recompute_frame: int = 500,
    snip_edges: bool = True,
) -> torch.Tensor:
1296
1297
    """Extract pitch based on method described in *A pitch extraction algorithm tuned
    for automatic speech recognition* [:footcite:`6854049`].
moto's avatar
moto committed
1298

moto's avatar
moto committed
1299
1300
1301
1302
    .. devices:: CPU

    .. properties:: TorchScript

moto's avatar
moto committed
1303
1304
1305
1306
1307
1308
1309
1310
    This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi.

    Args:
        waveform (Tensor):
            The input waveform of shape `(..., time)`.
        sample_rate (float):
            Sample rate of `waveform`.
        frame_length (float, optional):
moto's avatar
moto committed
1311
            Frame length in milliseconds. (default: 25.0)
moto's avatar
moto committed
1312
        frame_shift (float, optional):
moto's avatar
moto committed
1313
            Frame shift in milliseconds. (default: 10.0)
moto's avatar
moto committed
1314
        min_f0 (float, optional):
moto's avatar
moto committed
1315
            Minimum F0 to search for (Hz)  (default: 50.0)
moto's avatar
moto committed
1316
        max_f0 (float, optional):
moto's avatar
moto committed
1317
            Maximum F0 to search for (Hz)  (default: 400.0)
moto's avatar
moto committed
1318
        soft_min_f0 (float, optional):
moto's avatar
moto committed
1319
            Minimum f0, applied in soft way, must not exceed min-f0  (default: 10.0)
moto's avatar
moto committed
1320
        penalty_factor (float, optional):
moto's avatar
moto committed
1321
            Cost factor for FO change.  (default: 0.1)
moto's avatar
moto committed
1322
        lowpass_cutoff (float, optional):
moto's avatar
moto committed
1323
            Cutoff frequency for LowPass filter (Hz) (default: 1000)
        resample_frequency (float, optional):
            Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff.
            (default: 4000)
        delta_pitch (float, optional):
            Smallest relative change in pitch that our algorithm measures. (default: 0.005)
        nccf_ballast (float, optional):
            Increasing this factor reduces NCCF for quiet frames (default: 7000)
        lowpass_filter_width (int, optional):
            Integer that determines filter width of lowpass filter, more gives sharper filter.
            (default: 1)
        upsample_filter_width (int, optional):
            Integer that determines filter width when upsampling NCCF. (default: 5)
        max_frames_latency (int, optional):
            Maximum number of frames of latency that we allow pitch tracking to introduce into
            the feature processing (affects output only if ``frames_per_chunk > 0`` and
            ``simulate_first_pass_online=True``) (default: 0)
        frames_per_chunk (int, optional):
            The number of frames used for energy normalization. (default: 0)
        simulate_first_pass_online (bool, optional):
            If true, the function will output features that correspond to what an online decoder
            would see in the first pass of decoding -- not the final version of the features,
            which is the default. (default: False)
            Relevant if ``frames_per_chunk > 0``.
        recompute_frame (int, optional):
            Only relevant for compatibility with online pitch extraction.
            A non-critical parameter; the frame at which we recompute some of the forward pointers,
            after revising our estimate of the signal energy.
            Relevant if ``frames_per_chunk > 0``. (default: 500)
        snip_edges (bool, optional):
            If this is set to false, the incomplete frames near the ending edge won't be snipped,
            so that the number of frames is the file size divided by the frame-shift.
            This makes different types of features give the same number of frames. (default: True)

    Returns:
       Tensor: Pitch feature. Shape: `(..., frames, 2)` where the last dimension
       corresponds to pitch and NCCF.
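
    Example:
        A minimal usage sketch; ``"speech.wav"`` is a hypothetical input file:

        >>> waveform, sample_rate = torchaudio.load("speech.wav")
        >>> pitch_feature = compute_kaldi_pitch(waveform, sample_rate)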
    """
    shape = waveform.shape
    waveform = waveform.reshape(-1, shape[-1])
    result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch(
        waveform,
        sample_rate,
        frame_length,
        frame_shift,
        min_f0,
        max_f0,
        soft_min_f0,
        penalty_factor,
        lowpass_cutoff,
        resample_frequency,
        delta_pitch,
        nccf_ballast,
        lowpass_filter_width,
        upsample_filter_width,
        max_frames_latency,
        frames_per_chunk,
        simulate_first_pass_online,
        recompute_frame,
        snip_edges,
    )
    result = result.reshape(shape[:-1] + result.shape[-2:])
    return result


def _get_sinc_resample_kernel(
    orig_freq: int,
    new_freq: int,
    gcd: int,
    lowpass_filter_width: int,
    rolloff: float,
    resampling_method: str,
    beta: Optional[float],
    device: torch.device = torch.device("cpu"),
    dtype: Optional[torch.dtype] = None,
):

    if not (int(orig_freq) == orig_freq and int(new_freq) == new_freq):
        raise Exception(
            "Frequencies must be of integer type to ensure quality resampling computation. "
            "To work around this, manually convert both frequencies to integer values "
            "that maintain their resampling rate ratio before passing them into the function. "
            "Example: To downsample a 44100 hz waveform by a factor of 8, use "
            "`orig_freq=8` and `new_freq=1` instead of `orig_freq=44100` and `new_freq=5512.5`. "
            "For more information, please refer to https://github.com/pytorch/audio/issues/1487."
        )

    if resampling_method not in ["sinc_interpolation", "kaiser_window"]:
        raise ValueError("Invalid resampling method: {}".format(resampling_method))

    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd

    assert lowpass_filter_width > 0
    kernels = []
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= rolloff

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #    y[j] = x(j / new_freq)
    # or,
    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This will explain the F.conv1d after, with a stride of orig_freq.
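    # For example, resampling from 44100 Hz to 48000 Hz reduces by gcd = 300 to
    # orig_freq = 147 and new_freq = 160: we build 160 filters (one per output phase)
    # and apply them with conv1d at stride 147.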
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx_dtype = dtype if dtype is not None else torch.float64
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=idx_dtype)

    for i in range(new_freq):
        t = (-i / new_freq + idx / orig_freq) * base_freq
        t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)

        # we do not use built in torch windows here as we need to evaluate the window
        # at specific positions, not over a regular grid.
        if resampling_method == "sinc_interpolation":
            window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
        else:
            # kaiser_window
            if beta is None:
                beta = 14.769656459379492
            beta_tensor = torch.tensor(float(beta))
            window = torch.i0(beta_tensor * torch.sqrt(1 - (t / lowpass_filter_width) ** 2)) / torch.i0(beta_tensor)
        t *= math.pi
        kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t)
        kernel.mul_(window)
        kernels.append(kernel)

    scale = base_freq / orig_freq
    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
    if dtype is None:
        kernels = kernels.to(dtype=torch.float32)
    return kernels, width


def _apply_sinc_resample_kernel(
    waveform: Tensor,
    orig_freq: int,
    new_freq: int,
    gcd: int,
    kernel: Tensor,
    width: int,
):
    if not waveform.is_floating_point():
        raise TypeError(f"Expected floating point type for waveform tensor, but received {waveform.dtype}.")

    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd

    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    num_wavs, length = waveform.shape
    waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
    resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
    resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
    target_length = int(math.ceil(new_freq * length / orig_freq))
    resampled = resampled[..., :target_length]

    # unpack batch
    resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
    return resampled


def resample(
    waveform: Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    resampling_method: str = "sinc_interpolation",
    beta: Optional[float] = None,
) -> Tensor:
    r"""Resamples the waveform at the new frequency using bandlimited interpolation. [:footcite:`RESAMPLE`].

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Note:
        ``transforms.Resample`` precomputes and reuses the resampling kernel, so using it will result in
        more efficient computation if resampling multiple waveforms with the same resampling parameters.

    Args:
        waveform (Tensor): The input signal of dimension `(..., time)`
        orig_freq (int): The original frequency of the signal
        new_freq (int): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
            but less efficient. (Default: ``6``)
        rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
            Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
        resampling_method (str, optional): The resampling method to use.
            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
        beta (float or None, optional): The shape parameter used for kaiser window.

    Returns:
        Tensor: The waveform at the new frequency of dimension `(..., time)`.
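
    Example:
        A minimal sketch, downsampling a random stand-in waveform from 48 kHz to 16 kHz:

        >>> waveform = torch.randn(1, 48000)
        >>> resampled = resample(waveform, orig_freq=48000, new_freq=16000)
        >>> resampled.shape
        torch.Size([1, 16000])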
    """

    assert orig_freq > 0.0 and new_freq > 0.0

    if orig_freq == new_freq:
        return waveform

    gcd = math.gcd(int(orig_freq), int(new_freq))

    kernel, width = _get_sinc_resample_kernel(
        orig_freq,
        new_freq,
        gcd,
        lowpass_filter_width,
        rolloff,
        resampling_method,
        beta,
        waveform.device,
        waveform.dtype,
    )
    resampled = _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)
    return resampled


@torch.jit.unused
def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
    """
    Calculate the word level edit (Levenshtein) distance between two sequences.

    .. devices:: CPU

    The function computes an edit distance allowing deletion, insertion and
    substitution. The result is an integer.

    For most applications, the two input sequences should be the same type. If
    two strings are given, the output is the edit distance between the two
    strings (character edit distance). If two lists of strings are given, the
    output is the edit distance between sentences (word edit distance). Users
    may want to normalize the output by the length of the reference sequence.

    Args:
        seq1 (Sequence): the first sequence to compare.
        seq2 (Sequence): the second sequence to compare.
    Returns:
        int: The distance between the first and second sequences.
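
    Example:
        Character edit distance between two strings, and word edit distance
        between two tokenized sentences:

        >>> edit_distance("kitten", "sitting")
        3
        >>> edit_distance(["hello", "world"], ["goodbye", "world"])
        1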
    """
    len_sent2 = len(seq2)
    dold = list(range(len_sent2 + 1))
    dnew = [0 for _ in range(len_sent2 + 1)]

    for i in range(1, len(seq1) + 1):
        dnew[0] = i
        for j in range(1, len_sent2 + 1):
            if seq1[i - 1] == seq2[j - 1]:
                dnew[j] = dold[j - 1]
            else:
                substitution = dold[j - 1] + 1
                insertion = dnew[j - 1] + 1
                deletion = dold[j] + 1
                dnew[j] = min(substitution, insertion, deletion)

        dnew, dold = dold, dnew

    return int(dold[-1])


def pitch_shift(
    waveform: Tensor,
    sample_rate: int,
    n_steps: int,
    bins_per_octave: int = 12,
    n_fft: int = 512,
    win_length: Optional[int] = None,
    hop_length: Optional[int] = None,
    window: Optional[Tensor] = None,
) -> Tensor:
    """
    Shift the pitch of a waveform by ``n_steps`` steps.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        waveform (Tensor): The input waveform of shape `(..., time)`.
        sample_rate (int): Sample rate of `waveform`.
        n_steps (int): The (fractional) steps to shift `waveform`.
        bins_per_octave (int, optional): The number of steps per octave (Default: ``12``).
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``).
        win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``).
        hop_length (int or None, optional): Length of hop between STFT windows. If None, then
            ``n_fft // 4`` is used (Default: ``None``).
        window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window.
            If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``).


    Returns:
        Tensor: The pitch-shifted audio waveform of shape `(..., time)`.
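
    Example:
        A minimal sketch, shifting a random stand-in waveform up four semitones:

        >>> waveform = torch.randn(1, 16000)
        >>> shifted = pitch_shift(waveform, sample_rate=16000, n_steps=4)
        >>> shifted.shape  # output length matches the input
        torch.Size([1, 16000])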
    """
    if hop_length is None:
        hop_length = n_fft // 4
    if win_length is None:
        win_length = n_fft
    if window is None:
        window = torch.hann_window(window_length=win_length, device=waveform.device)

    # pack batch
    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    ori_len = shape[-1]
    rate = 2.0 ** (-float(n_steps) / bins_per_octave)
    spec_f = torch.stft(
        input=waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        center=True,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    phase_advance = torch.linspace(0, math.pi * hop_length, spec_f.shape[-2], device=spec_f.device)[..., None]
    spec_stretch = phase_vocoder(spec_f, rate, phase_advance)
    len_stretch = int(round(ori_len / rate))
    waveform_stretch = torch.istft(
        spec_stretch, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window=window, length=len_stretch
    )
    waveform_shift = resample(waveform_stretch, int(sample_rate / rate), sample_rate)
    shift_len = waveform_shift.size()[-1]
    if shift_len > ori_len:
        waveform_shift = waveform_shift[..., :ori_len]
    else:
        waveform_shift = torch.nn.functional.pad(waveform_shift, [0, ori_len - shift_len])

    # unpack batch
    waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
    return waveform_shift


def rnnt_loss(
    logits: Tensor,
    targets: Tensor,
    logit_lengths: Tensor,
    target_lengths: Tensor,
    blank: int = -1,
    clamp: float = -1,
    reduction: str = "mean",
):
    """Compute the RNN Transducer loss from *Sequence Transduction with Recurrent Neural Networks*
    [:footcite:`graves2012sequence`].

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    The RNN Transducer loss extends the CTC loss by defining a distribution over output
    sequences of all lengths, and by jointly modelling both input-output and output-output
    dependencies.

    Args:
        logits (Tensor): Tensor of dimension `(batch, max seq length, max target length + 1, class)`
            containing output from joiner
        targets (Tensor): Tensor of dimension `(batch, max target length)` containing zero-padded targets
        logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder
        target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence
        blank (int, optional): blank label (Default: ``-1``)
        clamp (float, optional): clamp for gradients (Default: ``-1``)
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
    Returns:
        Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`,
        otherwise scalar.
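
    Example:
        A minimal sketch with random stand-in inputs (batch of 1, two encoder frames,
        one target token, three classes):

        >>> logits = torch.randn(1, 2, 2, 3, requires_grad=True)
        >>> targets = torch.tensor([[1]], dtype=torch.int32)
        >>> logit_lengths = torch.tensor([2], dtype=torch.int32)
        >>> target_lengths = torch.tensor([1], dtype=torch.int32)
        >>> loss = rnnt_loss(logits, targets, logit_lengths, target_lengths)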
    """
    if reduction not in ["none", "mean", "sum"]:
        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")

    if blank < 0:  # reinterpret blank index if blank < 0.
        blank = logits.shape[-1] + blank

    costs, _ = torch.ops.torchaudio.rnnt_loss(
        logits=logits,
        targets=targets,
        logit_lengths=logit_lengths,
        target_lengths=target_lengths,
        blank=blank,
        clamp=clamp,
    )

    if reduction == "mean":
        return costs.mean()
    elif reduction == "sum":
        return costs.sum()

    return costs


def psd(
    specgram: Tensor,
    mask: Optional[Tensor] = None,
    normalize: bool = True,
    eps: float = 1e-10,
) -> Tensor:
    """Compute cross-channel power spectral density (PSD) matrix.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        specgram (torch.Tensor): Multi-channel complex-valued spectrum.
            Tensor with dimensions `(..., channel, freq, time)`.
        mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
            Tensor with dimensions `(..., freq, time)`.
            (Default: ``None``)
        normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
        eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-10``)

    Returns:
        torch.Tensor: The complex-valued PSD matrix of the input spectrum.
        Tensor with dimensions `(..., freq, channel, channel)`
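
    Example:
        A minimal sketch on a random stand-in two-channel complex spectrum
        (257 frequency bins, 100 frames):

        >>> specgram = torch.randn(2, 257, 100, dtype=torch.cfloat)
        >>> psd_matrix = psd(specgram)
        >>> psd_matrix.shape
        torch.Size([257, 2, 2])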
    """
    specgram = specgram.transpose(-3, -2)  # shape (freq, channel, time)
    # outer product:
    # (..., ch_1, time) x (..., ch_2, time) -> (..., time, ch_1, ch_2)
    psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])

    if mask is not None:
        # Normalize the mask along the time dimension:
        if normalize:
            mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)

        psd = psd * mask[..., None, None]

    psd = psd.sum(dim=-3)
    return psd


def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

    Args:
        input (torch.Tensor): Tensor with dimensions `(..., channel, channel)`.
        dim1 (int, optional): The first dimension of the diagonal matrix.
            (Default: ``-1``)
        dim2 (int, optional): The second dimension of the diagonal matrix.
            (Default: ``-2``)

    Returns:
        Tensor: The trace of the input Tensor.
    """
    assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
    assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
    input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
    return input.sum(dim=-1)


def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    """Perform Tikhonov regularization (only modifying real part).

    Args:
        mat (torch.Tensor): Input matrix with dimensions `(..., channel, channel)`.
        reg (float, optional): Regularization factor. (Default: ``1e-7``)
        eps (float, optional): Value added to avoid an all-zero correlation matrix. (Default: ``1e-8``)

    Returns:
        Tensor: Regularized matrix with dimensions `(..., channel, channel)`.
    """
    # Add eps
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
    # in case that correlation_matrix is all-zero
    epsilon = epsilon + eps
    mat = mat + epsilon * eye[..., :, :]
    return mat


def mvdr_weights_souden(
    psd_s: Tensor,
    psd_n: Tensor,
    reference_channel: Union[int, Tensor],
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Tensor:
    r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
    by the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
    the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
    reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
    :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}

    Args:
        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor with dimensions `(..., freq, channel, channel)`.
        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor with dimensions `(..., freq, channel, channel)`.
        reference_channel (int or torch.Tensor): Specifies the reference channel.
            If the dtype is ``int``, it represents the reference channel index.
            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
1853
            is one-hot.
        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
            (Default: ``1e-8``)

    Returns:
        torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
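
    Example:
        A minimal sketch; the masks below are random stand-ins for speech and noise
        Time-Frequency masks:

        >>> specgram = torch.randn(2, 257, 100, dtype=torch.cfloat)
        >>> mask_speech = torch.rand(257, 100)
        >>> psd_s = psd(specgram, mask=mask_speech)
        >>> psd_n = psd(specgram, mask=1 - mask_speech)
        >>> weights = mvdr_weights_souden(psd_s, psd_n, reference_channel=0)
        >>> weights.shape
        torch.Size([257, 2])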
    """
    if diagonal_loading:
        psd_n = _tik_reg(psd_n, reg=diag_eps)
    numerator = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
    if torch.jit.isinstance(reference_channel, int):
        beamform_weights = ws[..., :, reference_channel]
    elif torch.jit.isinstance(reference_channel, Tensor):
        reference_channel = reference_channel.to(psd_n.dtype)
        # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
        beamform_weights = torch.einsum("...c,...c->...", [ws, reference_channel[..., None, None, :]])
    else:
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")

    return beamform_weights


def mvdr_weights_rtf(
    rtf: Tensor,
    psd_n: Tensor,
    reference_channel: Optional[Union[int, Tensor]] = None,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Tensor:
    r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
    the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
    reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
    :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}

    where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.

    Args:
        rtf (torch.Tensor): The complex-valued RTF vector of target speech.
            Tensor with dimensions `(..., freq, channel)`.
        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor with dimensions `(..., freq, channel, channel)`.
        reference_channel (int or torch.Tensor or None, optional): Specifies the reference channel.
            If the dtype is ``int``, it represents the reference channel index.
            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
1916
            is one-hot. (Default: ``None``)
        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
            (Default: ``1e-8``)

    Returns:
        torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
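
    Example:
        A minimal sketch, estimating the RTF with :py:func:`rtf_evd` from a stand-in
        speech PSD matrix:

        >>> specgram = torch.randn(2, 257, 100, dtype=torch.cfloat)
        >>> mask = torch.rand(257, 100)
        >>> rtf = rtf_evd(psd(specgram, mask=mask))
        >>> psd_n = psd(specgram, mask=1 - mask)
        >>> weights = mvdr_weights_rtf(rtf, psd_n, reference_channel=0)
        >>> weights.shape
        torch.Size([257, 2])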
    """
    if diagonal_loading:
        psd_n = _tik_reg(psd_n, reg=diag_eps)
    # numerator = psd_n.inv() @ stv
    numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1)  # (..., freq, channel)
    # denominator = stv^H @ psd_n.inv() @ stv
    denominator = torch.einsum("...d,...d->...", [rtf.conj(), numerator])
    beamform_weights = numerator / (denominator.real.unsqueeze(-1) + eps)
    # normalize the numerator
    if reference_channel is not None:
        if torch.jit.isinstance(reference_channel, int):
            scale = rtf[..., reference_channel].conj()
        elif torch.jit.isinstance(reference_channel, Tensor):
            reference_channel = reference_channel.to(psd_n.dtype)
            scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
        else:
            raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")

        beamform_weights = beamform_weights * scale[..., None]

    return beamform_weights


def rtf_evd(psd_s: Tensor) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.

    .. devices:: CPU CUDA

    .. properties:: TorchScript

    Args:
        psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor of dimension `(..., freq, channel, channel)`

    Returns:
        Tensor: The estimated complex-valued RTF of target speech.
        Tensor of dimension `(..., freq, channel)`
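
    Example:
        A minimal sketch on a random stand-in Hermitian PSD matrix:

        >>> m = torch.randn(257, 2, 2, dtype=torch.cfloat)
        >>> psd_s = m @ m.conj().transpose(-2, -1)  # make the matrix Hermitian
        >>> rtf = rtf_evd(psd_s)
        >>> rtf.shape
        torch.Size([257, 2])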
    """
    _, v = torch.linalg.eigh(psd_s)  # v is sorted along with eigenvalues in ascending order
    rtf = v[..., -1]  # choose the eigenvector with max eigenvalue
    return rtf


def rtf_power(
    psd_s: Tensor,
    psd_n: Tensor,
    reference_channel: Union[int, Tensor],
    n_iter: int = 3,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor with dimensions `(..., freq, channel, channel)`.
        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor with dimensions `(..., freq, channel, channel)`.
        reference_channel (int or torch.Tensor): Specifies the reference channel.
            If the dtype is ``int``, it represents the reference channel index.
            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
            is one-hot.
        n_iter (int, optional): The number of power iterations. (Default: ``3``)
        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)

    Returns:
        torch.Tensor: The estimated complex-valued RTF of target speech.
        Tensor of dimension `(..., freq, channel)`.
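
    Example:
        A minimal sketch with stand-in speech and noise PSD matrices from :py:func:`psd`:

        >>> specgram = torch.randn(2, 257, 100, dtype=torch.cfloat)
        >>> mask = torch.rand(257, 100)
        >>> psd_s = psd(specgram, mask=mask)
        >>> psd_n = psd(specgram, mask=1 - mask)
        >>> rtf = rtf_power(psd_s, psd_n, reference_channel=0)
        >>> rtf.shape
        torch.Size([257, 2])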
    """
    assert n_iter > 0, "The number of iterations must be greater than 0."
    # Apply diagonal loading to psd_n to improve robustness.
    if diagonal_loading:
        psd_n = _tik_reg(psd_n, reg=diag_eps)
    # phi is regarded as the first iteration
    phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    if torch.jit.isinstance(reference_channel, int):
        rtf = phi[..., reference_channel]
    elif torch.jit.isinstance(reference_channel, Tensor):
        reference_channel = reference_channel.to(psd_n.dtype)
        rtf = torch.einsum("...c,...c->...", [phi, reference_channel[..., None, None, :]])
    else:
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
    rtf = rtf.unsqueeze(-1)  # (..., freq, channel, 1)
    if n_iter >= 2:
        # The number of iterations in the for loop is `n_iter - 2`
        # because the `phi` above and `torch.matmul(psd_s, rtf)` are regarded as
        # two iterations.
        for _ in range(n_iter - 2):
            rtf = torch.matmul(phi, rtf)
        rtf = torch.matmul(psd_s, rtf)
    else:
        # if there is only one iteration, the rtf is psd_s[..., reference_channel]
        # which is psd_n @ phi @ ref_channel
        rtf = torch.matmul(psd_n, rtf)
    return rtf.squeeze(-1)


def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
    r"""Apply the beamforming weight to the multi-channel noisy spectrum to obtain the single-channel enhanced spectrum.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the beamforming weight for the :math:`f`-th frequency bin,
    :math:`\textbf{Y}` is the multi-channel spectrum for the :math:`f`-th frequency bin.

    Args:
        beamform_weights (Tensor): The complex-valued beamforming weight matrix.
            Tensor of dimension `(..., freq, channel)`
        specgram (Tensor): The multi-channel complex-valued noisy spectrum.
            Tensor of dimension `(..., channel, freq, time)`

    Returns:
        Tensor: The single-channel complex-valued enhanced spectrum.
            Tensor of dimension `(..., freq, time)`
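
    Example:
        A minimal sketch with random stand-in weights and spectrum:

        >>> specgram = torch.randn(2, 257, 100, dtype=torch.cfloat)
        >>> weights = torch.randn(257, 2, dtype=torch.cfloat)
        >>> enhanced = apply_beamforming(weights, specgram)
        >>> enhanced.shape
        torch.Size([257, 100])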
    """
    # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
    specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])
    return specgram_enhanced