OpenDAS / Torchaudio

Commit 86370639 (unverified), authored Jul 29, 2021 by Joel Frank, committed by GitHub on Jul 29, 2021.

Add LFCC feature to transforms (#1611)

Summary:
- Add linear_fbank method
- Add LFCC in transforms

Parent: 108a32d9
Changes: showing 10 changed files with 271 additions and 26 deletions (+271 / -26).

- docs/source/functional.rst (+5 / -0)
- docs/source/transforms.rst (+7 / -0)
- test/torchaudio_unittest/functional/torchscript_consistency_impl.py (+15 / -0)
- test/torchaudio_unittest/transforms/autograd_test_impl.py (+7 / -0)
- test/torchaudio_unittest/transforms/batch_consistency_test.py (+10 / -0)
- test/torchaudio_unittest/transforms/torchscript_consistency_impl.py (+4 / -0)
- test/torchaudio_unittest/transforms/transforms_test.py (+59 / -0)
- torchaudio/functional/__init__.py (+2 / -0)
- torchaudio/functional/functional.py (+64 / -9)
- torchaudio/transforms.py (+98 / -17)
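For orientation, the new transform added by this commit can be exercised in a couple of lines. A minimal usage sketch, not part of the diff; it assumes a torchaudio build that includes this commit, and the output shape follows the default hop length of 200 samples:

    import torch
    import torchaudio.transforms as T

    # Synthetic stand-in for a 1-second mono clip at 16 kHz.
    waveform = torch.randn(1, 16000)

    # Defaults match the values exercised in the new tests below.
    lfcc = T.LFCC(sample_rate=16000, n_filter=128, n_lfcc=40)(waveform)
    print(lfcc.shape)  # torch.Size([1, 40, 81]): (channel, n_lfcc, time)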
docs/source/functional.rst

@@ -26,6 +26,11 @@ create_fb_matrix
 .. autofunction:: create_fb_matrix

+linear_fbanks
+-------------
+
+.. autofunction:: linear_fbanks
+
 create_dct
 ----------
docs/source/transforms.rst

@@ -59,6 +59,13 @@ Transforms are common audio transforms. They can be chained together using :clas
   .. automethod:: forward

+:hidden:`LFCC`
+~~~~~~~~~~~~~~
+
+.. autoclass:: LFCC
+
+  .. automethod:: forward
+
 :hidden:`MuLawEncoding`
 ~~~~~~~~~~~~~~~~~~~~~~~
test/torchaudio_unittest/functional/torchscript_consistency_impl.py

@@ -132,6 +132,21 @@ class Functional(TempDirMixin, TestBaseMixin):
         dummy = torch.zeros(1, 1)
         self._assert_consistency(func, dummy)

+    def test_linear_fbanks(self):
+        if self.device != torch.device('cpu'):
+            raise unittest.SkipTest('No need to perform test on device other than CPU')
+
+        def func(_):
+            n_stft = 100
+            f_min = 0.0
+            f_max = 20.0
+            n_filter = 10
+            sample_rate = 16000
+            return F.linear_fbanks(n_stft, f_min, f_max, n_filter, sample_rate)
+
+        dummy = torch.zeros(1, 1)
+        self._assert_consistency(func, dummy)
+
     def test_amplitude_to_DB(self):
         def func(tensor):
             multiplier = 10.0
test/torchaudio_unittest/transforms/autograd_test_impl.py

@@ -108,6 +108,13 @@ class AutogradTestMixin(TestBaseMixin):
         waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
         self.assert_grad(transform, [waveform])

+    @parameterized.expand([(False, ), (True, )])
+    def test_lfcc(self, log_lf):
+        sample_rate = 8000
+        transform = T.LFCC(sample_rate=sample_rate, log_lf=log_lf)
+        waveform = get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2)
+        self.assert_grad(transform, [waveform])
+
     def test_compute_deltas(self):
         transform = T.ComputeDeltas()
         spec = torch.rand(10, 20)
test/torchaudio_unittest/transforms/batch_consistency_test.py

@@ -127,6 +127,16 @@ class TestTransforms(common_utils.TorchaudioTestCase):
         computed = torchaudio.transforms.MFCC()(waveform.repeat(3, 1, 1))

         self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)

+    def test_batch_lfcc(self):
+        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=2)
+
+        # Single then transform then batch
+        expected = torchaudio.transforms.LFCC()(waveform).repeat(3, 1, 1, 1)
+
+        # Batch then transform
+        computed = torchaudio.transforms.LFCC()(waveform.repeat(3, 1, 1))
+
+        self.assertEqual(computed, expected, atol=1e-4, rtol=1e-5)
+
     @parameterized.expand([(True, ), (False, )])
     def test_batch_TimeStretch(self, test_pseudo_complex):
         rate = 2
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py

@@ -71,6 +71,10 @@ class Transforms(TempDirMixin, TestBaseMixin):
         tensor = torch.rand((1, 1000))
         self._assert_consistency(T.MFCC(), tensor)

+    def test_LFCC(self):
+        tensor = torch.rand((1, 1000))
+        self._assert_consistency(T.LFCC(), tensor)
+
     def test_Resample(self):
         sr1, sr2 = 16000, 8000
         tensor = common_utils.get_whitenoise(sample_rate=sr1)
test/torchaudio_unittest/transforms/transforms_test.py

@@ -180,6 +180,65 @@ class Tester(common_utils.TorchaudioTestCase):
         self.assertEqual(torch_mfcc_norm_none, norm_check)

+    def test_lfcc_defaults(self):
+        """Check default settings for LFCC transform."""
+        sample_rate = 16000
+        audio = common_utils.get_whitenoise(sample_rate=sample_rate)
+
+        n_lfcc = 40
+        n_filter = 128
+        lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate,
+                                                    n_filter=n_filter,
+                                                    n_lfcc=n_lfcc,
+                                                    norm='ortho')
+        torch_lfcc = lfcc_transform(audio)  # (1, 40, 81)
+        self.assertEqual(torch_lfcc.dim(), 3)
+        self.assertEqual(torch_lfcc.shape[1], n_lfcc)
+        self.assertEqual(torch_lfcc.shape[2], 81)
+
+    def test_lfcc_arg_passthrough(self):
+        """Check if kwargs get correctly passed to the underlying Spectrogram transform."""
+        sample_rate = 16000
+        audio = common_utils.get_whitenoise(sample_rate=sample_rate)
+
+        n_lfcc = 40
+        n_filter = 128
+        speckwargs = {'win_length': 200}
+        lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate,
+                                                    n_filter=n_filter,
+                                                    n_lfcc=n_lfcc,
+                                                    norm='ortho',
+                                                    speckwargs=speckwargs)
+        torch_lfcc = lfcc_transform(audio)  # (1, 40, 161)
+        self.assertEqual(torch_lfcc.shape[2], 161)
+
+    def test_lfcc_norms(self):
+        """Check if LFCC-DCT norm works correctly."""
+        sample_rate = 16000
+        audio = common_utils.get_whitenoise(sample_rate=sample_rate)
+
+        n_lfcc = 40
+        n_filter = 128
+        lfcc_transform = torchaudio.transforms.LFCC(sample_rate=sample_rate,
+                                                    n_filter=n_filter,
+                                                    n_lfcc=n_lfcc,
+                                                    norm='ortho')
+
+        lfcc_transform_norm_none = torchaudio.transforms.LFCC(sample_rate=sample_rate,
+                                                              n_filter=n_filter,
+                                                              n_lfcc=n_lfcc,
+                                                              norm=None)
+        torch_lfcc_norm_none = lfcc_transform_norm_none(audio)  # (1, 40, 81)
+
+        norm_check = lfcc_transform(audio)  # (1, 40, 81)
+        norm_check[:, 0, :] *= math.sqrt(n_filter) * 2
+        norm_check[:, 1:, :] *= math.sqrt(n_filter / 2) * 2
+
+        self.assertEqual(torch_lfcc_norm_none, norm_check)
+
     def test_resample_size(self):
         input_path = common_utils.get_asset_path('sinewave.wav')
         waveform, sample_rate = common_utils.load_wav(input_path)
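The scaling constants in test_lfcc_norms follow from how create_dct normalizes the DCT-II matrix: with norm=None every entry is 2·cos(...), while with norm='ortho' the coefficient-0 column carries a factor of sqrt(1/n_filter) and the rest sqrt(2/n_filter), so dividing the two conventions gives exactly 2·sqrt(n_filter) and 2·sqrt(n_filter/2). A quick standalone sketch to confirm this (illustrative only, not part of the diff):

    import math
    import torch
    import torchaudio.functional as F

    n_filter, n_lfcc = 128, 40
    dct_none = F.create_dct(n_lfcc, n_filter, norm=None)      # entries: 2 * cos(...)
    dct_ortho = F.create_dct(n_lfcc, n_filter, norm='ortho')  # columns rescaled for orthonormality

    # Coefficient 0:   2 / sqrt(1 / n_filter) = 2 * sqrt(n_filter)
    # Coefficients 1+: 2 / sqrt(2 / n_filter) = 2 * sqrt(n_filter / 2)
    scaled = dct_ortho.clone()
    scaled[:, 0] *= math.sqrt(n_filter) * 2
    scaled[:, 1:] *= math.sqrt(n_filter / 2) * 2
    assert torch.allclose(dct_none, scaled)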
torchaudio/functional/__init__.py

@@ -6,6 +6,7 @@ from .functional import (
     compute_kaldi_pitch,
     create_dct,
     create_fb_matrix,
+    linear_fbanks,
     DB_to_amplitude,
     detect_pitch_frequency,
     griffinlim,

@@ -55,6 +56,7 @@ __all__ = [
     'compute_kaldi_pitch',
     'create_dct',
     'create_fb_matrix',
+    'linear_fbanks',
     'DB_to_amplitude',
     'detect_pitch_frequency',
     'griffinlim',
torchaudio/functional/functional.py

@@ -19,6 +19,7 @@ __all__ = [
     "compute_deltas",
     "compute_kaldi_pitch",
     "create_fb_matrix",
+    "linear_fbanks",
     "create_dct",
     "compute_deltas",
     "detect_pitch_frequency",

@@ -376,6 +377,32 @@ def _mel_to_hz(mels: Tensor, mel_scale: str = "htk") -> Tensor:
     return freqs


+def _create_triangular_filterbank(
+        all_freqs: Tensor,
+        f_pts: Tensor,
+) -> Tensor:
+    """Create a triangular filter bank.
+
+    Args:
+        all_freqs (Tensor): STFT freq points of size (`n_freqs`).
+        f_pts (Tensor): Filter mid points of size (`n_filter`).
+
+    Returns:
+        fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
+    """
+    # Adapted from Librosa
+    # calculate the difference between each filter mid point and each stft freq point in hertz
+    f_diff = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
+    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)
+    # create overlapping triangles
+    zero = torch.zeros(1)
+    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_filter)
+    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_filter)
+    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
+
+    return fb
+
+
 def create_fb_matrix(
         n_freqs: int,
         f_min: float,

@@ -409,7 +436,6 @@ def create_fb_matrix(
         raise ValueError("norm must be one of None or 'slaney'")

     # freq bins
-    # Equivalent filterbank construction by Librosa
     all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

     # calculate mel freq bins

@@ -419,14 +445,8 @@ def create_fb_matrix(
     m_pts = torch.linspace(m_min, m_max, n_mels + 2)
     f_pts = _mel_to_hz(m_pts, mel_scale=mel_scale)

-    # calculate the difference between each mel point and each stft freq point in hertz
-    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
-    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
-    # create overlapping triangles
-    zero = torch.zeros(1)
-    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
-    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
-    fb = torch.max(zero, torch.min(down_slopes, up_slopes))
+    # create filterbank
+    fb = _create_triangular_filterbank(all_freqs, f_pts)

     if norm is not None and norm == "slaney":
         # Slaney-style mel is scaled to be approx constant energy per channel

@@ -443,6 +463,41 @@ def create_fb_matrix(
     return fb


+def linear_fbanks(
+        n_freqs: int,
+        f_min: float,
+        f_max: float,
+        n_filter: int,
+        sample_rate: int,
+) -> Tensor:
+    r"""Creates a linear triangular filterbank.
+
+    Args:
+        n_freqs (int): Number of frequencies to highlight/apply
+        f_min (float): Minimum frequency (Hz)
+        f_max (float): Maximum frequency (Hz)
+        n_filter (int): Number of (linear) triangular filters
+        sample_rate (int): Sample rate of the audio waveform
+
+    Returns:
+        Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_filter``),
+        i.e. the number of frequencies to highlight/apply by the number of filterbanks.
+        Each column is a filterbank, so that given a matrix A of
+        size (..., ``n_freqs``), the applied result would be
+        ``A * linear_fbanks(A.size(-1), ...)``.
+    """
+    # freq bins
+    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
+
+    # filter mid-points
+    f_pts = torch.linspace(f_min, f_max, n_filter + 2)
+
+    # create filterbank
+    fb = _create_triangular_filterbank(all_freqs, f_pts)
+
+    return fb
+
+
 def create_dct(
         n_mfcc: int,
         n_mels: int,
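Since linear_fbanks mirrors the mel filterbank construction but with evenly spaced center frequencies, its output is easy to sanity-check. A small sketch with illustrative values, not part of the diff, assuming a torchaudio build containing this commit:

    import torch
    import torchaudio.functional as F

    # 201 STFT bins (n_fft = 400), 10 triangular filters spread linearly over 0-8000 Hz.
    fb = F.linear_fbanks(n_freqs=201, f_min=0.0, f_max=8000.0, n_filter=10, sample_rate=16000)
    print(fb.shape)  # torch.Size([201, 10])
    print(fb.sum(dim=1).max())  # neighbouring triangles overlap, so interior bins sum to ~1.0

    # Applying the bank to a power spectrogram, as the docstring describes:
    specgram = torch.rand(1, 201, 50)  # (channel, freq, time)
    filtered = torch.matmul(specgram.transpose(-1, -2), fb).transpose(-1, -2)
    print(filtered.shape)  # torch.Size([1, 10, 50])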
torchaudio/transforms.py

@@ -21,6 +21,7 @@ __all__ = [
     'InverseMelScale',
     'MelSpectrogram',
     'MFCC',
+    'LFCC',
     'MuLawEncoding',
     'MuLawDecoding',
     'Resample',
@@ -282,16 +283,8 @@ class MelScale(torch.nn.Module):
             Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
         """

-        # pack batch
-        shape = specgram.size()
-        specgram = specgram.reshape(-1, shape[-2], shape[-1])
-
-        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
-        # -> (channel, time, n_mels).transpose(...)
-        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
-
-        # unpack batch
-        mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
+        # (..., time, freq) dot (freq, n_mels) -> (..., n_mels, time)
+        mel_specgram = torch.matmul(specgram.transpose(-1, -2), self.fb).transpose(-1, -2)

         return mel_specgram
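The refactor above works because torch.matmul broadcasts over any number of leading batch dimensions, so a 2-D filterbank can be applied to an arbitrarily batched spectrogram without reshaping. A standalone sketch of the equivalence (not part of the diff):

    import torch

    fb = torch.rand(201, 64)               # (freq, n_mels)
    specgram = torch.rand(5, 2, 201, 100)  # (batch, channel, freq, time)

    # matmul broadcasts over all leading dims, so no pack/unpack reshape is needed:
    out = torch.matmul(specgram.transpose(-1, -2), fb).transpose(-1, -2)
    print(out.shape)  # torch.Size([5, 2, 64, 100])

    # equivalent to the old pack -> matmul -> unpack dance:
    shape = specgram.size()
    packed = specgram.reshape(-1, shape[-2], shape[-1])
    old = torch.matmul(packed.transpose(1, 2), fb).transpose(1, 2)
    old = old.reshape(shape[:-2] + old.shape[-2:])
    print(torch.allclose(out, old))  # True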
@@ -532,10 +525,8 @@ class MFCC(torch.nn.Module):
         self.top_db = 80.0
         self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

-        if melkwargs is not None:
-            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
-        else:
-            self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)
+        melkwargs = melkwargs or {}
+        self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)

         if self.n_mfcc > self.MelSpectrogram.n_mels:
             raise ValueError('Cannot select more MFCC coefficients than # mel bins')
@@ -558,12 +549,102 @@ class MFCC(torch.nn.Module):
         else:
             mel_specgram = self.amplitude_to_DB(mel_specgram)

-        # (..., channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
-        # -> (..., channel, time, n_mfcc).transpose(...)
-        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
+        # (..., time, n_mels) dot (n_mels, n_mfcc) -> (..., n_mfcc, time)
+        mfcc = torch.matmul(mel_specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)

         return mfcc


+class LFCC(torch.nn.Module):
+    r"""Create the linear-frequency cepstrum coefficients from an audio signal.
+
+    By default, this calculates the LFCC on the DB-scaled linear filtered spectrogram.
+    This is not the textbook implementation, but is implemented here to
+    give consistency with librosa.
+
+    This output depends on the maximum value in the input spectrogram, and so
+    may return different values for an audio clip split into snippets vs. a full clip.
+
+    Args:
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        n_filter (int, optional): Number of linear filters to apply. (Default: ``128``)
+        n_lfcc (int, optional): Number of lfc coefficients to retain. (Default: ``40``)
+        f_min (float, optional): Minimum frequency. (Default: ``0.``)
+        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
+        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
+        norm (str, optional): norm to use. (Default: ``'ortho'``)
+        log_lf (bool, optional): whether to use log-lf spectrograms instead of db-scaled. (Default: ``False``)
+        speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``)
+    """
+    __constants__ = ['sample_rate', 'n_filter', 'n_lfcc', 'dct_type', 'top_db', 'log_lf']
+
+    def __init__(self,
+                 sample_rate: int = 16000,
+                 n_filter: int = 128,
+                 f_min: float = 0.,
+                 f_max: Optional[float] = None,
+                 n_lfcc: int = 40,
+                 dct_type: int = 2,
+                 norm: str = 'ortho',
+                 log_lf: bool = False,
+                 speckwargs: Optional[dict] = None) -> None:
+        super(LFCC, self).__init__()
+        supported_dct_types = [2]
+        if dct_type not in supported_dct_types:
+            raise ValueError('DCT type not supported: {}'.format(dct_type))
+        self.sample_rate = sample_rate
+        self.f_min = f_min
+        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
+        self.n_filter = n_filter
+        self.n_lfcc = n_lfcc
+        self.dct_type = dct_type
+        self.norm = norm
+        self.top_db = 80.0
+        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
+
+        speckwargs = speckwargs or {}
+        self.Spectrogram = Spectrogram(**speckwargs)
+
+        if self.n_lfcc > self.Spectrogram.n_fft:
+            raise ValueError('Cannot select more LFCC coefficients than # fft bins')
+
+        filter_mat = F.linear_fbanks(
+            n_freqs=self.Spectrogram.n_fft // 2 + 1,
+            f_min=self.f_min,
+            f_max=self.f_max,
+            n_filter=self.n_filter,
+            sample_rate=self.sample_rate,
+        )
+        self.register_buffer("filter_mat", filter_mat)
+
+        dct_mat = F.create_dct(self.n_lfcc, self.n_filter, self.norm)
+        self.register_buffer('dct_mat', dct_mat)
+        self.log_lf = log_lf
+
+    def forward(self, waveform: Tensor) -> Tensor:
+        r"""
+        Args:
+            waveform (Tensor): Tensor of audio of dimension (..., time).
+
+        Returns:
+            Tensor: Linear Frequency Cepstral Coefficients of size (..., ``n_lfcc``, time).
+        """
+        specgram = self.Spectrogram(waveform)
+
+        # (..., time, freq) dot (freq, n_filter) -> (..., n_filter, time)
+        specgram = torch.matmul(specgram.transpose(-1, -2), self.filter_mat).transpose(-1, -2)
+
+        if self.log_lf:
+            log_offset = 1e-6
+            specgram = torch.log(specgram + log_offset)
+        else:
+            specgram = self.amplitude_to_DB(specgram)
+
+        # (..., time, n_filter) dot (n_filter, n_lfcc) -> (..., n_lfcc, time)
+        lfcc = torch.matmul(specgram.transpose(-1, -2), self.dct_mat).transpose(-1, -2)
+        return lfcc
+
+
 class MuLawEncoding(torch.nn.Module):
     r"""Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
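End to end, T.LFCC is the straight composition of the pieces registered in its __init__: Spectrogram, the linear filterbank matmul, DB scaling (or log), and the DCT. A sketch of that equivalence using the transform's own buffers (illustrative, not part of the diff; assumes default arguments and a torchaudio build with this commit):

    import torch
    import torchaudio.transforms as T

    torch.manual_seed(0)
    waveform = torch.randn(1, 16000)

    lfcc_transform = T.LFCC()  # defaults: n_fft=400, n_filter=128, n_lfcc=40
    lfcc = lfcc_transform(waveform)

    # Manual pipeline, step by step:
    spec = lfcc_transform.Spectrogram(waveform)
    spec = torch.matmul(spec.transpose(-1, -2), lfcc_transform.filter_mat).transpose(-1, -2)
    spec = lfcc_transform.amplitude_to_DB(spec)  # log_lf=False -> DB scaling
    manual = torch.matmul(spec.transpose(-1, -2), lfcc_transform.dct_mat).transpose(-1, -2)

    print(torch.allclose(lfcc, manual))  # True: forward() is exactly this composition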