Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
27031755
"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "6534efd6ca5955316f60fce7563e3bcd97cca583"
Unverified
Commit
27031755
authored
Jan 25, 2021
by
Nicolas Hug
Committed by
GitHub
Jan 25, 2021
Browse files
Add SpectralCentroid transform (#1167)
parent
5547f204
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
132 additions
and
0 deletions
+132
-0
test/torchaudio_unittest/batch_consistency_test.py
test/torchaudio_unittest/batch_consistency_test.py
+11
-0
test/torchaudio_unittest/librosa_compatibility_test.py
test/torchaudio_unittest/librosa_compatibility_test.py
+13
-0
test/torchaudio_unittest/torchscript_consistency_impl.py
test/torchaudio_unittest/torchscript_consistency_impl.py
+19
-0
torchaudio/functional/__init__.py
torchaudio/functional/__init__.py
+1
-0
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+36
-0
torchaudio/transforms.py
torchaudio/transforms.py
+52
-0
No files found.
test/torchaudio_unittest/batch_consistency_test.py
View file @
27031755
...
@@ -283,3 +283,14 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -283,3 +283,14 @@ class TestTransforms(common_utils.TorchaudioTestCase):
# Batch then transform
# Batch then transform
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
def test_batch_spectral_centroid(self):
    """SpectralCentroid must give the same result whether batching happens before or after it."""
    sample_rate = 44100
    waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
    batch_repeats = (3, 1, 1)

    # Reference: transform the single waveform, then tile the result into a batch.
    single_output = torchaudio.transforms.SpectralCentroid(sample_rate)(waveform)
    expected = single_output.repeat(*batch_repeats)

    # Candidate: tile the waveform into a batch first, then transform the batch.
    batched_waveform = waveform.repeat(*batch_repeats)
    computed = torchaudio.transforms.SpectralCentroid(sample_rate)(batched_waveform)

    self.assertEqual(computed, expected)
test/torchaudio_unittest/librosa_compatibility_test.py
View file @
27031755
...
@@ -231,6 +231,19 @@ class TestTransforms(common_utils.TorchaudioTestCase):
...
@@ -231,6 +231,19 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self
.
assertEqual
(
self
.
assertEqual
(
torch_mfcc
.
type
(
librosa_mfcc_tensor
.
dtype
),
librosa_mfcc_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
torch_mfcc
.
type
(
librosa_mfcc_tensor
.
dtype
),
librosa_mfcc_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
self
.
assert_compatibilities_spectral_centroid
(
sample_rate
,
n_fft
,
hop_length
,
sound
,
sound_librosa
)
def assert_compatibilities_spectral_centroid(self, sample_rate, n_fft, hop_length, sound, sound_librosa):
    """Assert that torchaudio's SpectralCentroid agrees with librosa's spectral_centroid.

    Args:
        sample_rate: Sample rate passed to both implementations.
        n_fft: FFT size passed to both implementations.
        hop_length: Hop length passed to both implementations.
        sound: Input handed to the torchaudio transform.
        sound_librosa: Input handed to librosa (as ``y``).
    """
    transform = torchaudio.transforms.SpectralCentroid(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    # Drop singleton dimensions and move to CPU so the result can be compared
    # with the tensor built from librosa's numpy output.
    torch_centroid = transform(sound).squeeze().cpu()

    librosa_centroid = librosa.feature.spectral_centroid(
        y=sound_librosa,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    # Only the first row of librosa's output is compared.
    librosa_centroid = torch.from_numpy(librosa_centroid)[0]

    # Cast the torchaudio result to librosa's dtype before comparing.
    self.assertEqual(
        torch_centroid.type(librosa_centroid.dtype),
        librosa_centroid,
        atol=1e-5,
        rtol=1e-5,
    )
def
test_basics1
(
self
):
def
test_basics1
(
self
):
kwargs
=
{
kwargs
=
{
'n_fft'
:
400
,
'n_fft'
:
400
,
...
...
test/torchaudio_unittest/torchscript_consistency_impl.py
View file @
27031755
...
@@ -535,6 +535,20 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -535,6 +535,20 @@ class Functional(common_utils.TestBaseMixin):
self
.
_assert_consistency
(
func
,
waveform
)
self
.
_assert_consistency
(
func
,
waveform
)
def test_spectral_centroid(self):
    # Runs F.spectral_centroid on white noise through the shared consistency helper.
    # NOTE(review): presumably `_assert_consistency` compares eager vs TorchScript
    # output (going by the test module's name) -- confirm against the mixin.
    def func(tensor):
        # Fixed STFT configuration used for the check.
        sample_rate = 44100
        pad = 0
        n_fft = 400
        win_length = 400
        hop_length = 200
        # Build the window on the input's device and dtype.
        window = torch.hann_window(win_length, device=tensor.device, dtype=tensor.dtype)
        return F.spectral_centroid(tensor, sample_rate, pad, window, n_fft, hop_length, win_length)

    tensor = common_utils.get_whitenoise(sample_rate=44100)
    self._assert_consistency(func, tensor)
class
Transforms
(
common_utils
.
TestBaseMixin
):
class
Transforms
(
common_utils
.
TestBaseMixin
):
"""Implements test for Transforms that are performed for different devices"""
"""Implements test for Transforms that are performed for different devices"""
...
@@ -624,3 +638,8 @@ class Transforms(common_utils.TestBaseMixin):
...
@@ -624,3 +638,8 @@ class Transforms(common_utils.TestBaseMixin):
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"vad-go-mono-32000.wav"
)
waveform
,
sample_rate
=
common_utils
.
load_wav
(
filepath
)
waveform
,
sample_rate
=
common_utils
.
load_wav
(
filepath
)
self
.
_assert_consistency
(
T
.
Vad
(
sample_rate
=
sample_rate
),
waveform
)
self
.
_assert_consistency
(
T
.
Vad
(
sample_rate
=
sample_rate
),
waveform
)
def test_SpectralCentroid(self):
    # Runs the SpectralCentroid transform on white noise through the shared
    # consistency helper.
    sample_rate = 44100
    transform = T.SpectralCentroid(sample_rate=sample_rate)
    waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
    self._assert_consistency(transform, waveform)
torchaudio/functional/__init__.py
View file @
27031755
...
@@ -16,6 +16,7 @@ from .functional import (
...
@@ -16,6 +16,7 @@ from .functional import (
phase_vocoder
,
phase_vocoder
,
sliding_window_cmn
,
sliding_window_cmn
,
spectrogram
,
spectrogram
,
spectral_centroid
,
)
)
from
.filtering
import
(
from
.filtering
import
(
allpass_biquad
,
allpass_biquad
,
...
...
torchaudio/functional/functional.py
View file @
27031755
...
@@ -27,6 +27,7 @@ __all__ = [
...
@@ -27,6 +27,7 @@ __all__ = [
'mask_along_axis'
,
'mask_along_axis'
,
'mask_along_axis_iid'
,
'mask_along_axis_iid'
,
'sliding_window_cmn'
,
'sliding_window_cmn'
,
"spectral_centroid"
,
]
]
...
@@ -935,3 +936,38 @@ def sliding_window_cmn(
...
@@ -935,3 +936,38 @@ def sliding_window_cmn(
if
len
(
input_shape
)
==
2
:
if
len
(
input_shape
)
==
2
:
cmn_waveform
=
cmn_waveform
.
squeeze
(
0
)
cmn_waveform
=
cmn_waveform
.
squeeze
(
0
)
return
cmn_waveform
return
cmn_waveform
def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time). A frame whose magnitudes sum to zero
        yields ``nan``, since the weighted sum is divided by the magnitude sum.
    """
    # Magnitude (power=1) spectrogram; the frequency axis is second to last.
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)

    # Bin k of an n_fft-point one-sided FFT is centred at k * sample_rate / n_fft.
    # The top bin is therefore (n_fft // 2) * sample_rate / n_fft, which equals
    # sample_rate / 2 only when n_fft is even. True division (rather than
    # ``sample_rate // 2``) also keeps the half-Hz for odd sample rates.
    num_bins = 1 + n_fft // 2
    max_freq = sample_rate * (n_fft // 2) / n_fft
    freqs = torch.linspace(0, max_freq, steps=num_bins, device=specgram.device).reshape((-1, 1))

    freq_dim = -2
    # Magnitude-weighted mean frequency of each frame.
    return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
torchaudio/transforms.py
View file @
27031755
...
@@ -28,6 +28,7 @@ __all__ = [
...
@@ -28,6 +28,7 @@ __all__ = [
'TimeMasking'
,
'TimeMasking'
,
'SlidingWindowCmn'
,
'SlidingWindowCmn'
,
'Vad'
,
'Vad'
,
'SpectralCentroid'
,
]
]
...
@@ -1037,3 +1038,54 @@ class Vad(torch.nn.Module):
...
@@ -1037,3 +1038,54 @@ class Vad(torch.nn.Module):
hp_lifter_freq
=
self
.
hp_lifter_freq
,
hp_lifter_freq
=
self
.
hp_lifter_freq
,
lp_lifter_freq
=
self
.
lp_lifter_freq
,
lp_lifter_freq
=
self
.
lp_lifter_freq
,
)
)
class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is the magnitude-weighted average of the frequency
    values of each frame.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window (Tensor, optional): A window tensor that is applied/multiplied to each frame.
            (Default: ``torch.hann_window(win_length)``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spectral_centroid = transforms.SpectralCentroid(sample_rate)(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window: Optional[Tensor] = None) -> None:
        super().__init__()

        # Resolve the defaults that depend on other arguments.
        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 2
        if window is None:
            window = torch.hann_window(win_length)

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.pad = pad
        # Registered as a buffer so the window is part of the module's state.
        self.register_buffer('window', window)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """
        return F.spectral_centroid(
            waveform,
            self.sample_rate,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
        )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment