Unverified Commit 83a312ce authored by hwangjeff's avatar hwangjeff Committed by GitHub
Browse files

Remove torchaudio._internal.fft module (#1631)

`torchaudio._internal.fft` was originally added to account for the introduction of module `torch.fft`, when `torch.fft` could refer to either a module or function. Now that `torch.fft` refers unambiguously to a module, we remove `torchaudio._internal.fft` and replace references to it with `torch.fft`.
parent f5dbb002
"""Compatibility module for fft-related functions
In PyTorch 1.7, the new `torch.fft` module was introduced.
To use this new module, one has to explicitly import `torch.fft`. however this will change
the reference `torch.fft` is pointing from function to module.
And this change takes effect not only in the client code but also in already-imported libraries too.
Similarly, if a library does the explicit import, the rest of the application code must use the
`torch.fft.fft` function.
For this reason, to migrate the deprecated functions of fft-family, we need to use the new
implementation under `torch.fft` without explicitly importing `torch.fft` module.
This module provides a simple interface for the migration, abstracting away
the access to the underlying C functions.
Once the deprecated functions are removed from PyTorch and `torch.fft` starts to always represent
the new module, we can get rid of this module and call functions under `torch.fft` directly.
"""
from typing import Optional
import torch
def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
# see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
return torch._C._fft.fft_rfft(input, n, dim, norm)
......@@ -5,7 +5,6 @@ import torch
from torch import Tensor
import torchaudio
import torchaudio._internal.fft
__all__ = [
'get_mel_banks',
......@@ -290,7 +289,7 @@ def spectrogram(waveform: Tensor,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1, 2)
fft = torchaudio._internal.fft.rfft(strided_input)
fft = torch.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1)
......@@ -572,7 +571,7 @@ def fbank(waveform: Tensor,
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
# size (m, padded_window_size // 2 + 1)
spectrum = torchaudio._internal.fft.rfft(strided_input).abs()
spectrum = torch.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.)
......
......@@ -5,8 +5,6 @@ from typing import Optional
import torch
from torch import Tensor
import torchaudio._internal.fft
def _dB2Linear(x: float) -> float:
return math.exp(x * math.log(10) / 20.0)
......@@ -1301,7 +1299,7 @@ def _measure(
dftBuf[measure_len_ws:dft_len_ws].zero_()
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
_dftBuf = torchaudio._internal.fft.rfft(dftBuf)
_dftBuf = torch.fft.rfft(dftBuf)
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
_dftBuf[:spectrum_start].zero_()
......@@ -1338,7 +1336,7 @@ def _measure(
_cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
_cepstrum_Buf = torchaudio._internal.fft.rfft(_cepstrum_Buf)
_cepstrum_Buf = torch.fft.rfft(_cepstrum_Buf)
result: float = float(
torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment