Unverified Commit 48d2b572 authored by moto's avatar moto Committed by GitHub
Browse files

Migrate torch.rfft to torch.fft.rfft and cfloat (#941)

parent b7c17f80
"""Compatibility module for fft-related functions
In PyTorch 1.7, the new `torch.fft` module was introduced.
To use this new module, one has to explicitly import `torch.fft`. however this will change
the reference `torch.fft` is pointing from function to module.
And this change takes effect not only in the client code but also in already-imported libraries too.
Similarly, if a library does the explicit import, the rest of the application code must use the
`torch.fft.fft` function.
For this reason, to migrate the deprecated functions of fft-family, we need to use the new
implementation under `torch.fft` without explicitly importing `torch.fft` module.
This module provides a simple interface for the migration, abstracting away
the access to the underlying C functions.
Once the deprecated functions are removed from PyTorch and `torch.fft` starts to always represent
the new module, we can get rid of this module and call functions under `torch.fft` directly.
"""
from typing import Optional
import torch
def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
# see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
return torch._C._fft.fft_rfft(input, n, dim, norm)
...@@ -2,9 +2,11 @@ from typing import Tuple ...@@ -2,9 +2,11 @@ from typing import Tuple
import math import math
import torch import torch
import torchaudio
from torch import Tensor from torch import Tensor
import torchaudio
import torchaudio._internal.fft
__all__ = [ __all__ = [
'get_mel_banks', 'get_mel_banks',
'inverse_mel_scale', 'inverse_mel_scale',
...@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor, ...@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient) snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1, 2) # size (m, padded_window_size // 2 + 1, 2)
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True) fft = torchaudio._internal.fft.rfft(strided_input)
# Convert the FFT into a power spectrum # Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log() # size (m, padded_window_size // 2 + 1) power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
...@@ -570,12 +572,10 @@ def fbank(waveform: Tensor, ...@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient) snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1, 2) # size (m, padded_window_size // 2 + 1)
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True) spectrum = torchaudio._internal.fft.rfft(strided_input).abs()
if use_power:
power_spectrum = fft.pow(2).sum(2) # size (m, padded_window_size // 2 + 1) spectrum = spectrum.pow(2.)
if not use_power:
power_spectrum = power_spectrum.pow(0.5)
# size (num_mel_bins, padded_window_size // 2) # size (num_mel_bins, padded_window_size // 2)
mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency, mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
...@@ -586,7 +586,7 @@ def fbank(waveform: Tensor, ...@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0) mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = torch.mm(power_spectrum, mel_energies.T) mel_energies = torch.mm(spectrum, mel_energies.T)
if use_log_fbank: if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering) # avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
......
...@@ -6,6 +6,7 @@ import warnings ...@@ -6,6 +6,7 @@ import warnings
import torch import torch
from torch import Tensor from torch import Tensor
import torchaudio._internal.fft
__all__ = [ __all__ = [
"spectrogram", "spectrogram",
...@@ -2073,7 +2074,7 @@ def _measure( ...@@ -2073,7 +2074,7 @@ def _measure(
dftBuf[measure_len_ws:dft_len_ws].zero_() dftBuf[measure_len_ws:dft_len_ws].zero_()
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf); # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
_dftBuf = torch.rfft(dftBuf, 1) _dftBuf = torchaudio._internal.fft.rfft(dftBuf)
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf)); # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
_dftBuf[:spectrum_start].zero_() _dftBuf[:spectrum_start].zero_()
...@@ -2082,7 +2083,7 @@ def _measure( ...@@ -2082,7 +2083,7 @@ def _measure(
if boot_count >= 0 \ if boot_count >= 0 \
else measure_smooth_time_mult else measure_smooth_time_mult
_d = complex_norm(_dftBuf[spectrum_start:spectrum_end]) _d = _dftBuf[spectrum_start:spectrum_end].abs()
spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult)) spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
_d = spectrum[spectrum_start:spectrum_end] ** 2 _d = spectrum[spectrum_start:spectrum_end] ** 2
...@@ -2106,12 +2107,9 @@ def _measure( ...@@ -2106,12 +2107,9 @@ def _measure(
_cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_() _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf); # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
_cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1) _cepstrum_Buf = torchaudio._internal.fft.rfft(_cepstrum_Buf)
result: float = float(torch.sum( result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
complex_norm(
_cepstrum_Buf[cepstrum_start:cepstrum_end],
power=2.0)))
result = \ result = \
math.log(result / (cepstrum_end - cepstrum_start)) \ math.log(result / (cepstrum_end - cepstrum_start)) \
if result > 0 \ if result > 0 \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment