Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
27031755
Unverified
Commit
27031755
authored
Jan 25, 2021
by
Nicolas Hug
Committed by
GitHub
Jan 25, 2021
Browse files
Add SpectralCentroid transform (#1167)
parent
5547f204
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
132 additions
and
0 deletions
+132
-0
test/torchaudio_unittest/batch_consistency_test.py
test/torchaudio_unittest/batch_consistency_test.py
+11
-0
test/torchaudio_unittest/librosa_compatibility_test.py
test/torchaudio_unittest/librosa_compatibility_test.py
+13
-0
test/torchaudio_unittest/torchscript_consistency_impl.py
test/torchaudio_unittest/torchscript_consistency_impl.py
+19
-0
torchaudio/functional/__init__.py
torchaudio/functional/__init__.py
+1
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+36
-0
torchaudio/transforms.py
torchaudio/transforms.py
+52
-0
No files found.
test/torchaudio_unittest/batch_consistency_test.py
View file @
27031755
...
@@ -283,3 +283,14 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -283,3 +283,14 @@ class TestTransforms(common_utils.TorchaudioTestCase):
# Batch then transform
# Batch then transform
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def test_batch_spectral_centroid(self):
    """Check that SpectralCentroid gives the same result on a batch as on a single item.

    The transform is applied once to a single waveform (and the result is
    tiled three times) and once to the waveform tiled three times; both
    outputs must be equal.
    """
    sample_rate = 44100
    waveform = common_utils.get_whitenoise(sample_rate=sample_rate)

    # Reference path: transform the single waveform first, then tile the output.
    single_transform = torchaudio.transforms.SpectralCentroid(sample_rate)
    single_output = single_transform(waveform)
    expected = single_output.repeat(3, 1, 1)

    # Batched path: tile the waveform first, then transform the whole batch.
    batched_waveform = waveform.repeat(3, 1, 1)
    batch_transform = torchaudio.transforms.SpectralCentroid(sample_rate)
    computed = batch_transform(batched_waveform)

    self.assertEqual(computed, expected)
test/torchaudio_unittest/librosa_compatibility_test.py
View file @
27031755
...
@@ -231,6 +231,19 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -231,6 +231,19 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self
.
assertEqual
(
self
.
assertEqual
(
torch_mfcc
.
type
(
librosa_mfcc_tensor
.
dtype
),
librosa_mfcc_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
torch_mfcc
.
type
(
librosa_mfcc_tensor
.
dtype
),
librosa_mfcc_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
self
.
assert_compatibilities_spectral_centroid
(
sample_rate
,
n_fft
,
hop_length
,
sound
,
sound_librosa
)
def assert_compatibilities_spectral_centroid(self, sample_rate, n_fft, hop_length, sound, sound_librosa):
    """Assert that torchaudio's SpectralCentroid matches librosa's spectral_centroid.

    Args:
        sample_rate: Sample rate passed to both implementations.
        n_fft: FFT size passed to both implementations.
        hop_length: Hop length passed to both implementations.
        sound: Input handed to the torchaudio transform.
        sound_librosa: Input handed to librosa as ``y``.
    """
    # torchaudio side: build the transform, run it, drop singleton dims, move to CPU.
    transform = torchaudio.transforms.SpectralCentroid(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    out_torch = transform(sound).squeeze().cpu()

    # librosa side: take the first row of the returned array as a tensor.
    librosa_result = librosa.feature.spectral_centroid(
        y=sound_librosa,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    out_librosa = torch.from_numpy(librosa_result)[0]

    # Compare in librosa's dtype so only numerical differences remain.
    torch_as_reference_dtype = out_torch.type(out_librosa.dtype)
    self.assertEqual(torch_as_reference_dtype, out_librosa, atol=1e-5, rtol=1e-5)
def
test_basics1
(
self
):
def
test_basics1
(
self
):
kwargs
=
{
kwargs
=
{
'n_fft'
:
400
,
'n_fft'
:
400
,
...
...
test/torchaudio_unittest/torchscript_consistency_impl.py
View file @
27031755
...
@@ -535,6 +535,20 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -535,6 +535,20 @@ class Functional(common_utils.TestBaseMixin):
self
.
_assert_consistency
(
func
,
waveform
)
self
.
_assert_consistency
(
func
,
waveform
)
def test_spectral_centroid(self):
    """Run the consistency check on ``F.spectral_centroid`` with fixed STFT settings."""

    def func(tensor):
        # Fixed STFT configuration used for the consistency check.
        rate = 44100
        fft_size = 400
        window_size = 400
        hop_size = 200
        padding = 0
        # Build the window on the same device and dtype as the input.
        hann = torch.hann_window(window_size, device=tensor.device, dtype=tensor.dtype)
        return F.spectral_centroid(tensor, rate, padding, hann, fft_size, hop_size, window_size)

    noise = common_utils.get_whitenoise(sample_rate=44100)
    self._assert_consistency(func, noise)
class
Transforms
(
common_utils
.
TestBaseMixin
):
class
Transforms
(
common_utils
.
TestBaseMixin
):
"""Implements test for Transforms that are performed for different devices"""
"""Implements test for Transforms that are performed for different devices"""
...
@@ -624,3 +638,8 @@ class Transforms(common_utils.TestBaseMixin):
...
@@ -624,3 +638,8 @@ class Transforms(common_utils.TestBaseMixin):
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
waveform
,
sample_rate
=
common_utils
.
load_wav
(
filepath
)
waveform
,
sample_rate
=
common_utils
.
load_wav
(
filepath
)
self
.
_assert_consistency
(
T
.
Vad
(
sample_rate
=
sample_rate
),
waveform
)
self
.
_assert_consistency
(
T
.
Vad
(
sample_rate
=
sample_rate
),
waveform
)
def test_SpectralCentroid(self):
    """Run the consistency check on the ``T.SpectralCentroid`` transform with white noise."""
    sample_rate = 44100
    noise = common_utils.get_whitenoise(sample_rate=sample_rate)
    transform = T.SpectralCentroid(sample_rate=sample_rate)
    self._assert_consistency(transform, noise)
torchaudio/functional/__init__.py
View file @
27031755
...
@@ -16,6 +16,7 @@ from .functional import (
...
@@ -16,6 +16,7 @@ from .functional import (
phase_vocoder
,
phase_vocoder
,
sliding_window_cmn
,
sliding_window_cmn
,
spectrogram
,
spectrogram
,
spectral_centroid
,
)
)
from
.filtering
import
(
from
.filtering
import
(
allpass_biquad
,
allpass_biquad
,
...
...
torchaudio/functional/functional.py
View file @
27031755
...
@@ -27,6 +27,7 @@ __all__ = [
...
@@ -27,6 +27,7 @@ __all__ = [
'mask_along_axis'
,
'mask_along_axis'
,
'mask_along_axis_iid'
,
'mask_along_axis_iid'
,
'sliding_window_cmn'
,
'sliding_window_cmn'
,
"spectral_centroid"
,
]
]
...
@@ -935,3 +936,38 @@ def sliding_window_cmn(
...
@@ -935,3 +936,38 @@ def sliding_window_cmn(
if
len
(
input_shape
)
==
2
:
if
len
(
input_shape
)
==
2
:
cmn_waveform
=
cmn_waveform
.
squeeze
(
0
)
cmn_waveform
=
cmn_waveform
.
squeeze
(
0
)
return
cmn_waveform
return
cmn_waveform
def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Note:
        The result is the ratio of two sums over the frequency axis. For a
        frame whose magnitudes sum to zero that ratio is ``0 / 0``, so no
        finite centroid is produced for it.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time)
    """
    # Magnitude (power=1) spectrogram, not normalized.
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
    # Frequencies run from 0 to the Nyquist frequency over ``n_fft // 2 + 1``
    # bins. True division keeps the half-sample for odd sample rates; floor
    # division would lower the upper bound and skew every bin frequency.
    freqs = torch.linspace(0, sample_rate / 2, steps=1 + n_fft // 2,
                           device=specgram.device).reshape((-1, 1))
    # The frequency axis is second to last; the column-shaped ``freqs``
    # broadcasts across the last axis.
    freq_dim = -2
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
torchaudio/transforms.py
View file @
27031755
...
@@ -28,6 +28,7 @@ __all__ = [
...
@@ -28,6 +28,7 @@ __all__ = [
'TimeMasking'
,
'TimeMasking'
,
'SlidingWindowCmn'
,
'SlidingWindowCmn'
,
'Vad'
,
'Vad'
,
'SpectralCentroid'
,
]
]
...
@@ -1037,3 +1038,54 @@ class Vad(torch.nn.Module):
...
@@ -1037,3 +1038,54 @@ class Vad(torch.nn.Module):
hp_lifter_freq
=
self
.
hp_lifter_freq
,
hp_lifter_freq
=
self
.
hp_lifter_freq
,
lp_lifter_freq
=
self
.
lp_lifter_freq
,
lp_lifter_freq
=
self
.
lp_lifter_freq
,
)
)
class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window (Tensor, optional): A window tensor that is applied/multiplied to each frame.
            (Default: ``torch.hann_window(win_length)``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spectral_centroid = transforms.SpectralCentroid(sample_rate)(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window: Optional[Tensor] = None) -> None:
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        # The window spans the whole FFT unless a length is given.
        self.win_length = win_length if win_length is not None else n_fft
        # The hop defaults to half of the resolved window length.
        self.hop_length = hop_length if hop_length is not None else self.win_length // 2
        if window is None:
            window = torch.hann_window(self.win_length)
        # Stored as a buffer so the window is part of the module's state.
        self.register_buffer('window', window)
        self.pad = pad

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """
        # Delegate to the functional implementation with the stored settings.
        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft,
                                   self.hop_length, self.win_length)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment