Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
48d2b572
Unverified
Commit
48d2b572
authored
Nov 05, 2020
by
moto
Committed by
GitHub
Nov 05, 2020
Browse files
Migrate torch.rfft to torch.fft.rfft and cfloat (#941)
parent
b7c17f80
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
17 deletions
+42
-17
torchaudio/_internal/fft.py
torchaudio/_internal/fft.py
+27
-0
torchaudio/compliance/kaldi.py
torchaudio/compliance/kaldi.py
+10
-10
torchaudio/functional.py
torchaudio/functional.py
+5
-7
No files found.
torchaudio/_internal/fft.py
0 → 100644
View file @
48d2b572
"""Compatibility module for fft-related functions
In PyTorch 1.7, the new `torch.fft` module was introduced.
To use this new module, one has to explicitly import `torch.fft`. However, doing so changes
what the name `torch.fft` refers to, from a function to a module.
This change takes effect not only in the client code but also in libraries that have already been imported.
Similarly, if a library does the explicit import, the rest of the application code must use the
`torch.fft.fft` function.
For this reason, to migrate away from the deprecated fft-family functions, we need to use the new
implementation under `torch.fft` without explicitly importing the `torch.fft` module.
This module provides a simple interface for the migration, abstracting away
the access to the underlying C functions.
Once the deprecated functions are removed from PyTorch and `torch.fft` starts to always represent
the new module, we can get rid of this module and call functions under `torch.fft` directly.
"""
from
typing
import
Optional
import
torch
def rfft(
    input: torch.Tensor,
    n: Optional[int] = None,
    dim: int = -1,
    norm: Optional[str] = None,
) -> torch.Tensor:
    """Forward to the C-level ``fft_rfft`` without importing ``torch.fft``.

    The signature mirrors ``torch.fft.rfft``; see
    https://pytorch.org/docs/master/fft.html#torch.fft.rfft for the meaning
    of each argument.

    Args:
        input: Tensor handed to the underlying FFT routine.
        n: Optional signal length, passed through unchanged.
        dim: Dimension along which the transform is taken (default: last).
        norm: Optional normalization mode string, passed through unchanged.

    Returns:
        The tensor produced by ``torch._C._fft.fft_rfft``.
    """
    # Go through the private C binding directly: touching it does not require
    # an explicit ``import torch.fft``, which is what this module avoids.
    return torch._C._fft.fft_rfft(input, n, dim, norm)
torchaudio/compliance/kaldi.py
View file @
48d2b572
...
@@ -2,9 +2,11 @@ from typing import Tuple
...
@@ -2,9 +2,11 @@ from typing import Tuple
import
math
import
math
import
torch
import
torch
import
torchaudio
from
torch
import
Tensor
from
torch
import
Tensor
import
torchaudio
import
torchaudio._internal.fft
__all__
=
[
__all__
=
[
'get_mel_banks'
,
'get_mel_banks'
,
'inverse_mel_scale'
,
'inverse_mel_scale'
,
...
@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
...
@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
snip_edges
,
raw_energy
,
energy_floor
,
dither
,
remove_dc_offset
,
preemphasis_coefficient
)
snip_edges
,
raw_energy
,
energy_floor
,
dither
,
remove_dc_offset
,
preemphasis_coefficient
)
# size (m, padded_window_size // 2 + 1, 2)
# size (m, padded_window_size // 2 + 1, 2)
fft
=
torch
.
rfft
(
strided_input
,
1
,
normalized
=
False
,
onesided
=
True
)
fft
=
torch
audio
.
_internal
.
fft
.
rfft
(
strided_input
)
# Convert the FFT into a power spectrum
# Convert the FFT into a power spectrum
power_spectrum
=
torch
.
max
(
fft
.
pow
(
2
).
sum
(
2
),
epsilon
).
log
()
# size (m, padded_window_size // 2 + 1)
power_spectrum
=
torch
.
max
(
fft
.
abs
().
pow
(
2.
),
epsilon
).
log
()
# size (m, padded_window_size // 2 + 1)
power_spectrum
[:,
0
]
=
signal_log_energy
power_spectrum
[:,
0
]
=
signal_log_energy
power_spectrum
=
_subtract_column_mean
(
power_spectrum
,
subtract_mean
)
power_spectrum
=
_subtract_column_mean
(
power_spectrum
,
subtract_mean
)
...
@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
...
@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
waveform
,
padded_window_size
,
window_size
,
window_shift
,
window_type
,
blackman_coeff
,
waveform
,
padded_window_size
,
window_size
,
window_shift
,
window_type
,
blackman_coeff
,
snip_edges
,
raw_energy
,
energy_floor
,
dither
,
remove_dc_offset
,
preemphasis_coefficient
)
snip_edges
,
raw_energy
,
energy_floor
,
dither
,
remove_dc_offset
,
preemphasis_coefficient
)
# size (m, padded_window_size // 2 + 1, 2)
# size (m, padded_window_size // 2 + 1)
fft
=
torch
.
rfft
(
strided_input
,
1
,
normalized
=
False
,
onesided
=
True
)
spectrum
=
torchaudio
.
_internal
.
fft
.
rfft
(
strided_input
).
abs
()
if
use_power
:
power_spectrum
=
fft
.
pow
(
2
).
sum
(
2
)
# size (m, padded_window_size // 2 + 1)
spectrum
=
spectrum
.
pow
(
2.
)
if
not
use_power
:
power_spectrum
=
power_spectrum
.
pow
(
0.5
)
# size (num_mel_bins, padded_window_size // 2)
# size (num_mel_bins, padded_window_size // 2)
mel_energies
,
_
=
get_mel_banks
(
num_mel_bins
,
padded_window_size
,
sample_frequency
,
mel_energies
,
_
=
get_mel_banks
(
num_mel_bins
,
padded_window_size
,
sample_frequency
,
...
@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
...
@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
mel_energies
=
torch
.
nn
.
functional
.
pad
(
mel_energies
,
(
0
,
1
),
mode
=
'constant'
,
value
=
0
)
mel_energies
=
torch
.
nn
.
functional
.
pad
(
mel_energies
,
(
0
,
1
),
mode
=
'constant'
,
value
=
0
)
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies
=
torch
.
mm
(
power_
spectrum
,
mel_energies
.
T
)
mel_energies
=
torch
.
mm
(
spectrum
,
mel_energies
.
T
)
if
use_log_fbank
:
if
use_log_fbank
:
# avoid log of zero (which should be prevented anyway by dithering)
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies
=
torch
.
max
(
mel_energies
,
_get_epsilon
(
device
,
dtype
)).
log
()
mel_energies
=
torch
.
max
(
mel_energies
,
_get_epsilon
(
device
,
dtype
)).
log
()
...
...
torchaudio/functional.py
View file @
48d2b572
...
@@ -6,6 +6,7 @@ import warnings
...
@@ -6,6 +6,7 @@ import warnings
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torchaudio._internal.fft
__all__
=
[
__all__
=
[
"spectrogram"
,
"spectrogram"
,
...
@@ -2073,7 +2074,7 @@ def _measure(
...
@@ -2073,7 +2074,7 @@ def _measure(
dftBuf
[
measure_len_ws
:
dft_len_ws
].
zero_
()
dftBuf
[
measure_len_ws
:
dft_len_ws
].
zero_
()
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
_dftBuf
=
torch
.
rfft
(
dftBuf
,
1
)
_dftBuf
=
torch
audio
.
_internal
.
fft
.
rfft
(
dftBuf
)
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
_dftBuf
[:
spectrum_start
].
zero_
()
_dftBuf
[:
spectrum_start
].
zero_
()
...
@@ -2082,7 +2083,7 @@ def _measure(
...
@@ -2082,7 +2083,7 @@ def _measure(
if
boot_count
>=
0
\
if
boot_count
>=
0
\
else
measure_smooth_time_mult
else
measure_smooth_time_mult
_d
=
complex_norm
(
_dftBuf
[
spectrum_start
:
spectrum_end
])
_d
=
_dftBuf
[
spectrum_start
:
spectrum_end
]
.
abs
(
)
spectrum
[
spectrum_start
:
spectrum_end
].
mul_
(
mult
).
add_
(
_d
*
(
1
-
mult
))
spectrum
[
spectrum_start
:
spectrum_end
].
mul_
(
mult
).
add_
(
_d
*
(
1
-
mult
))
_d
=
spectrum
[
spectrum_start
:
spectrum_end
]
**
2
_d
=
spectrum
[
spectrum_start
:
spectrum_end
]
**
2
...
@@ -2106,12 +2107,9 @@ def _measure(
...
@@ -2106,12 +2107,9 @@ def _measure(
_cepstrum_Buf
[
spectrum_end
:
dft_len_ws
>>
1
].
zero_
()
_cepstrum_Buf
[
spectrum_end
:
dft_len_ws
>>
1
].
zero_
()
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
_cepstrum_Buf
=
torch
.
rfft
(
_cepstrum_Buf
,
1
)
_cepstrum_Buf
=
torch
audio
.
_internal
.
fft
.
rfft
(
_cepstrum_Buf
)
result
:
float
=
float
(
torch
.
sum
(
result
:
float
=
float
(
torch
.
sum
(
_cepstrum_Buf
[
cepstrum_start
:
cepstrum_end
].
abs
().
pow
(
2
)))
complex_norm
(
_cepstrum_Buf
[
cepstrum_start
:
cepstrum_end
],
power
=
2.0
)))
result
=
\
result
=
\
math
.
log
(
result
/
(
cepstrum_end
-
cepstrum_start
))
\
math
.
log
(
result
/
(
cepstrum_end
-
cepstrum_start
))
\
if
result
>
0
\
if
result
>
0
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment