Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
56ab0368
Unverified
Commit
56ab0368
authored
Jul 14, 2021
by
Joel Frank
Committed by
GitHub
Jul 13, 2021
Browse files
MFCC test refactor (#1618)
parent
0e513208
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
20 deletions
+37
-20
test/torchaudio_unittest/transforms/transforms_test.py
test/torchaudio_unittest/transforms/transforms_test.py
+37
-20
No files found.
test/torchaudio_unittest/transforms/transforms_test.py
View file @
56ab0368
...
@@ -127,41 +127,58 @@ class Tester(common_utils.TorchaudioTestCase):
...
@@ -127,41 +127,58 @@ class Tester(common_utils.TorchaudioTestCase):
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertTrue
(
fb_matrix_transform
.
fb
.
sum
(
1
).
ge
(
0.
).
all
())
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
self
.
assertEqual
(
fb_matrix_transform
.
fb
.
size
(),
(
400
,
100
))
def
test_mfcc
(
self
):
def
test_mfcc_defaults
(
self
):
audio_orig
=
self
.
waveform
.
clone
()
"""Check the default configuration of the MFCC transform.
audio_scaled
=
self
.
scale
(
audio_orig
)
# (1, 16000)
"""
sample_rate
=
16000
sample_rate
=
16000
audio
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
n_mfcc
=
40
n_mfcc
=
40
n_mels
=
128
mfcc_transform
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
mfcc_transform
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
n_mfcc
=
n_mfcc
,
n_mfcc
=
n_mfcc
,
norm
=
'ortho'
)
norm
=
'ortho'
)
# check defaults
torch_mfcc
=
mfcc_transform
(
audio
)
# (1, 40, 81)
torch_mfcc
=
mfcc_transform
(
audio_scaled
)
# (1, 40, 321)
self
.
assertEqual
(
torch_mfcc
.
dim
(),
3
)
self
.
assertTrue
(
torch_mfcc
.
dim
()
==
3
)
self
.
assertEqual
(
torch_mfcc
.
shape
[
1
],
n_mfcc
)
self
.
assertTrue
(
torch_mfcc
.
shape
[
1
]
==
n_mfcc
)
self
.
assertEqual
(
torch_mfcc
.
shape
[
2
],
81
)
self
.
assertTrue
(
torch_mfcc
.
shape
[
2
]
==
321
)
# check melkwargs are passed through
def
test_mfcc_kwargs_passthrough
(
self
):
"""Check kwargs get correctly passed to the MelSpectrogram transform.
"""
sample_rate
=
16000
audio
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
n_mfcc
=
40
melkwargs
=
{
'win_length'
:
200
}
melkwargs
=
{
'win_length'
:
200
}
mfcc_transform2
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
mfcc_transform
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
n_mfcc
=
n_mfcc
,
n_mfcc
=
n_mfcc
,
norm
=
'ortho'
,
norm
=
'ortho'
,
melkwargs
=
melkwargs
)
melkwargs
=
melkwargs
)
torch_mfcc2
=
mfcc_transform2
(
audio_scaled
)
# (1, 40, 641)
torch_mfcc
=
mfcc_transform
(
audio
)
# (1, 40, 161)
self
.
assertTrue
(
torch_mfcc2
.
shape
[
2
]
==
641
)
self
.
assertEqual
(
torch_mfcc
.
shape
[
2
],
161
)
def
test_mfcc_norms
(
self
):
"""Check if MFCC-DCT norms work correctly.
"""
sample_rate
=
16000
audio
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
n_mfcc
=
40
n_mels
=
128
mfcc_transform
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
n_mfcc
=
n_mfcc
,
norm
=
'ortho'
)
# check norms work correctly
# check norms work correctly
mfcc_transform_norm_none
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
mfcc_transform_norm_none
=
torchaudio
.
transforms
.
MFCC
(
sample_rate
=
sample_rate
,
n_mfcc
=
n_mfcc
,
n_mfcc
=
n_mfcc
,
norm
=
None
)
norm
=
None
)
torch_mfcc_norm_none
=
mfcc_transform_norm_none
(
audio
_scaled
)
# (1, 40,
32
1)
torch_mfcc_norm_none
=
mfcc_transform_norm_none
(
audio
)
# (1, 40,
8
1)
norm_check
=
torch_mfcc
.
clone
(
)
norm_check
=
mfcc_transform
(
audio
)
norm_check
[:,
0
,
:]
*=
math
.
sqrt
(
n_mels
)
*
2
norm_check
[:,
0
,
:]
*=
math
.
sqrt
(
n_mels
)
*
2
norm_check
[:,
1
:,
:]
*=
math
.
sqrt
(
n_mels
/
2
)
*
2
norm_check
[:,
1
:,
:]
*=
math
.
sqrt
(
n_mels
/
2
)
*
2
self
.
assert
True
(
torch_mfcc_norm_none
.
allclose
(
norm_check
)
)
self
.
assert
Equal
(
torch_mfcc_norm_none
,
norm_check
)
def
test_resample_size
(
self
):
def
test_resample_size
(
self
):
input_path
=
common_utils
.
get_asset_path
(
'sinewave.wav'
)
input_path
=
common_utils
.
get_asset_path
(
'sinewave.wav'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment