Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
93cc6da7
Unverified
Commit
93cc6da7
authored
May 15, 2020
by
moto
Committed by
GitHub
May 15, 2020
Browse files
Adopt PyTorch's test util to librosa compatibilities test (#646)
parent
6fc8953c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
80 additions
and
86 deletions
+80
-86
test/test_librosa_compatibility.py
test/test_librosa_compatibility.py
+80
-86
No files found.
test/test_librosa_compatibility.py
View file @
93cc6da7
...
@@ -3,6 +3,7 @@ import os
...
@@ -3,6 +3,7 @@ import os
import
unittest
import
unittest
import
torch
import
torch
from
torch.testing._internal.common_utils
import
TestCase
import
torchaudio
import
torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
from
torchaudio.common_utils
import
IMPORT_LIBROSA
from
torchaudio.common_utils
import
IMPORT_LIBROSA
...
@@ -17,15 +18,8 @@ import pytest
...
@@ -17,15 +18,8 @@ import pytest
import
common_utils
import
common_utils
class
_LibrosaMixin
:
@
unittest
.
skipIf
(
not
IMPORT_LIBROSA
,
"Librosa not available"
)
"""Automatically skip tests if librosa is not available"""
class
TestFunctional
(
TestCase
):
def
setUp
(
self
):
super
().
setUp
()
if
not
IMPORT_LIBROSA
:
raise
unittest
.
SkipTest
(
'Librosa not available'
)
class
TestFunctional
(
_LibrosaMixin
,
unittest
.
TestCase
):
"""Test suite for functions in `functional` module."""
"""Test suite for functions in `functional` module."""
def
test_griffinlim
(
self
):
def
test_griffinlim
(
self
):
# NOTE: This test is flaky without a fixed random seed
# NOTE: This test is flaky without a fixed random seed
...
@@ -51,7 +45,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
...
@@ -51,7 +45,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
momentum
=
momentum
,
init
=
init
,
length
=
length
)
momentum
=
momentum
,
init
=
init
,
length
=
length
)
lr_out
=
torch
.
from_numpy
(
lr_out
).
unsqueeze
(
0
)
lr_out
=
torch
.
from_numpy
(
lr_out
).
unsqueeze
(
0
)
torch
.
testing
.
assert_allclose
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
def
_test_create_fb
(
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
):
def
_test_create_fb
(
self
,
n_mels
=
40
,
sample_rate
=
22050
,
n_fft
=
2048
,
fmin
=
0.0
,
fmax
=
8000.0
,
norm
=
None
):
librosa_fb
=
librosa
.
filters
.
mel
(
sr
=
sample_rate
,
librosa_fb
=
librosa
.
filters
.
mel
(
sr
=
sample_rate
,
...
@@ -69,8 +63,8 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
...
@@ -69,8 +63,8 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
norm
=
norm
)
norm
=
norm
)
for
i_mel_bank
in
range
(
n_mels
):
for
i_mel_bank
in
range
(
n_mels
):
torch
.
testing
.
assert_allclose
(
fb
[:,
i_mel_bank
],
torch
.
tensor
(
librosa_fb
[
i_mel_bank
]),
self
.
assertEqual
(
atol
=
1e-4
,
rtol
=
1e-5
)
fb
[:,
i_mel_bank
],
torch
.
tensor
(
librosa_fb
[
i_mel_bank
]),
atol
=
1e-4
,
rtol
=
1e-5
)
def
test_create_fb
(
self
):
def
test_create_fb
(
self
):
self
.
_test_create_fb
()
self
.
_test_create_fb
()
...
@@ -101,7 +95,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
...
@@ -101,7 +95,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
lr_out
=
librosa
.
core
.
power_to_db
(
spec
.
numpy
())
lr_out
=
librosa
.
core
.
power_to_db
(
spec
.
numpy
())
lr_out
=
torch
.
from_numpy
(
lr_out
)
lr_out
=
torch
.
from_numpy
(
lr_out
)
torch
.
testing
.
assert_allclose
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
# Amplitude to DB
# Amplitude to DB
multiplier
=
20.0
multiplier
=
20.0
...
@@ -110,7 +104,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
...
@@ -110,7 +104,7 @@ class TestFunctional(_LibrosaMixin, unittest.TestCase):
lr_out
=
librosa
.
core
.
amplitude_to_db
(
spec
.
numpy
())
lr_out
=
librosa
.
core
.
amplitude_to_db
(
spec
.
numpy
())
lr_out
=
torch
.
from_numpy
(
lr_out
)
lr_out
=
torch
.
from_numpy
(
lr_out
)
torch
.
testing
.
assert_allclose
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
ta_out
,
lr_out
,
atol
=
5e-5
,
rtol
=
1e-5
)
@
pytest
.
mark
.
parametrize
(
'complex_specgrams'
,
[
@
pytest
.
mark
.
parametrize
(
'complex_specgrams'
,
[
...
@@ -161,7 +155,10 @@ def _load_audio_asset(*asset_paths, **kwargs):
...
@@ -161,7 +155,10 @@ def _load_audio_asset(*asset_paths, **kwargs):
return
sound
,
sample_rate
return
sound
,
sample_rate
def
_test_compatibilities
(
n_fft
,
hop_length
,
power
,
n_mels
,
n_mfcc
,
sample_rate
):
@
unittest
.
skipIf
(
not
IMPORT_LIBROSA
,
"Librosa not available"
)
class
TestTransforms
(
TestCase
):
"""Test suite for functions in `transforms` module."""
def
assert_compatibilities
(
self
,
n_fft
,
hop_length
,
power
,
n_mels
,
n_mfcc
,
sample_rate
):
sound
,
sample_rate
=
_load_audio_asset
(
'sinewave.wav'
)
sound
,
sample_rate
=
_load_audio_asset
(
'sinewave.wav'
)
sound_librosa
=
sound
.
cpu
().
numpy
().
squeeze
()
# (64000)
sound_librosa
=
sound
.
cpu
().
numpy
().
squeeze
()
# (64000)
...
@@ -172,7 +169,7 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
...
@@ -172,7 +169,7 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
y
=
sound_librosa
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
power
=
power
)
y
=
sound_librosa
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
power
=
power
)
out_torch
=
spect_transform
(
sound
).
squeeze
().
cpu
()
out_torch
=
spect_transform
(
sound
).
squeeze
().
cpu
()
torch
.
testing
.
assert_allclose
(
out_torch
,
torch
.
from_numpy
(
out_librosa
),
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
out_torch
,
torch
.
from_numpy
(
out_librosa
),
atol
=
1e-5
,
rtol
=
1e-5
)
# test mel spectrogram
# test mel spectrogram
melspect_transform
=
torchaudio
.
transforms
.
MelSpectrogram
(
melspect_transform
=
torchaudio
.
transforms
.
MelSpectrogram
(
...
@@ -183,24 +180,24 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
...
@@ -183,24 +180,24 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
hop_length
=
hop_length
,
n_mels
=
n_mels
,
htk
=
True
,
norm
=
None
)
hop_length
=
hop_length
,
n_mels
=
n_mels
,
htk
=
True
,
norm
=
None
)
librosa_mel_tensor
=
torch
.
from_numpy
(
librosa_mel
)
librosa_mel_tensor
=
torch
.
from_numpy
(
librosa_mel
)
torch_mel
=
melspect_transform
(
sound
).
squeeze
().
cpu
()
torch_mel
=
melspect_transform
(
sound
).
squeeze
().
cpu
()
torch
.
testing
.
assert_allclose
(
self
.
assertEqual
(
torch_mel
.
type
(
librosa_mel_tensor
.
dtype
),
librosa_mel_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
torch_mel
.
type
(
librosa_mel_tensor
.
dtype
),
librosa_mel_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
# test s2db
# test s2db
power_to_db_transform
=
torchaudio
.
transforms
.
AmplitudeToDB
(
'power'
,
80.
)
power_to_db_transform
=
torchaudio
.
transforms
.
AmplitudeToDB
(
'power'
,
80.
)
power_to_db_torch
=
power_to_db_transform
(
spect_transform
(
sound
)).
squeeze
().
cpu
()
power_to_db_torch
=
power_to_db_transform
(
spect_transform
(
sound
)).
squeeze
().
cpu
()
power_to_db_librosa
=
librosa
.
core
.
spectrum
.
power_to_db
(
out_librosa
)
power_to_db_librosa
=
librosa
.
core
.
spectrum
.
power_to_db
(
out_librosa
)
torch
.
testing
.
assert_allclose
(
power_to_db_torch
,
torch
.
from_numpy
(
power_to_db_librosa
),
atol
=
5e-3
,
rtol
=
1e-5
)
self
.
assertEqual
(
power_to_db_torch
,
torch
.
from_numpy
(
power_to_db_librosa
),
atol
=
5e-3
,
rtol
=
1e-5
)
mag_to_db_transform
=
torchaudio
.
transforms
.
AmplitudeToDB
(
'magnitude'
,
80.
)
mag_to_db_transform
=
torchaudio
.
transforms
.
AmplitudeToDB
(
'magnitude'
,
80.
)
mag_to_db_torch
=
mag_to_db_transform
(
torch
.
abs
(
sound
)).
squeeze
().
cpu
()
mag_to_db_torch
=
mag_to_db_transform
(
torch
.
abs
(
sound
)).
squeeze
().
cpu
()
mag_to_db_librosa
=
librosa
.
core
.
spectrum
.
amplitude_to_db
(
sound_librosa
)
mag_to_db_librosa
=
librosa
.
core
.
spectrum
.
amplitude_to_db
(
sound_librosa
)
torch
.
testing
.
assert_allclose
(
mag_to_db_torch
,
torch
.
from_numpy
(
mag_to_db_librosa
),
atol
=
5e-3
,
rtol
=
1e-5
)
self
.
assertEqual
(
mag_to_db_torch
,
torch
.
from_numpy
(
mag_to_db_librosa
),
atol
=
5e-3
,
rtol
=
1e-5
)
power_to_db_torch
=
power_to_db_transform
(
melspect_transform
(
sound
)).
squeeze
().
cpu
()
power_to_db_torch
=
power_to_db_transform
(
melspect_transform
(
sound
)).
squeeze
().
cpu
()
db_librosa
=
librosa
.
core
.
spectrum
.
power_to_db
(
librosa_mel
)
db_librosa
=
librosa
.
core
.
spectrum
.
power_to_db
(
librosa_mel
)
db_librosa_tensor
=
torch
.
from_numpy
(
db_librosa
)
db_librosa_tensor
=
torch
.
from_numpy
(
db_librosa
)
torch
.
testing
.
assert_allclose
(
self
.
assertEqual
(
power_to_db_torch
.
type
(
db_librosa_tensor
.
dtype
),
db_librosa_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
power_to_db_torch
.
type
(
db_librosa_tensor
.
dtype
),
db_librosa_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
# test MFCC
# test MFCC
...
@@ -222,12 +219,9 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
...
@@ -222,12 +219,9 @@ def _test_compatibilities(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate)
librosa_mfcc_tensor
=
torch
.
from_numpy
(
librosa_mfcc
)
librosa_mfcc_tensor
=
torch
.
from_numpy
(
librosa_mfcc
)
torch_mfcc
=
mfcc_transform
(
sound
).
squeeze
().
cpu
()
torch_mfcc
=
mfcc_transform
(
sound
).
squeeze
().
cpu
()
torch
.
testing
.
assert_allclose
(
self
.
assertEqual
(
torch_mfcc
.
type
(
librosa_mfcc_tensor
.
dtype
),
librosa_mfcc_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
torch_mfcc
.
type
(
librosa_mfcc_tensor
.
dtype
),
librosa_mfcc_tensor
,
atol
=
5e-3
,
rtol
=
1e-5
)
class
TestTransforms
(
_LibrosaMixin
,
unittest
.
TestCase
):
"""Test suite for functions in `transforms` module."""
def
test_basics1
(
self
):
def
test_basics1
(
self
):
kwargs
=
{
kwargs
=
{
'n_fft'
:
400
,
'n_fft'
:
400
,
...
@@ -237,7 +231,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
...
@@ -237,7 +231,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc'
:
40
,
'n_mfcc'
:
40
,
'sample_rate'
:
16000
'sample_rate'
:
16000
}
}
_tes
t_compatibilities
(
**
kwargs
)
self
.
asser
t_compatibilities
(
**
kwargs
)
def
test_basics2
(
self
):
def
test_basics2
(
self
):
kwargs
=
{
kwargs
=
{
...
@@ -248,7 +242,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
...
@@ -248,7 +242,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc'
:
20
,
'n_mfcc'
:
20
,
'sample_rate'
:
16000
'sample_rate'
:
16000
}
}
_tes
t_compatibilities
(
**
kwargs
)
self
.
asser
t_compatibilities
(
**
kwargs
)
# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
@
unittest
.
skipIf
(
'CI'
in
os
.
environ
,
'Test is known to fail on CI'
)
@
unittest
.
skipIf
(
'CI'
in
os
.
environ
,
'Test is known to fail on CI'
)
...
@@ -261,7 +255,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
...
@@ -261,7 +255,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc'
:
50
,
'n_mfcc'
:
50
,
'sample_rate'
:
24000
'sample_rate'
:
24000
}
}
_tes
t_compatibilities
(
**
kwargs
)
self
.
asser
t_compatibilities
(
**
kwargs
)
def
test_basics4
(
self
):
def
test_basics4
(
self
):
kwargs
=
{
kwargs
=
{
...
@@ -272,7 +266,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
...
@@ -272,7 +266,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
'n_mfcc'
:
40
,
'n_mfcc'
:
40
,
'sample_rate'
:
16000
'sample_rate'
:
16000
}
}
_tes
t_compatibilities
(
**
kwargs
)
self
.
asser
t_compatibilities
(
**
kwargs
)
@
unittest
.
skipIf
(
"sox"
not
in
common_utils
.
BACKENDS
,
"sox not available"
)
@
unittest
.
skipIf
(
"sox"
not
in
common_utils
.
BACKENDS
,
"sox not available"
)
@
common_utils
.
AudioBackendScope
(
"sox"
)
@
common_utils
.
AudioBackendScope
(
"sox"
)
...
@@ -295,7 +289,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
...
@@ -295,7 +289,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
S
=
spec_lr
,
sr
=
sample_rate
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
S
=
spec_lr
,
sr
=
sample_rate
,
n_fft
=
n_fft
,
hop_length
=
hop_length
,
win_length
=
n_fft
,
center
=
True
,
window
=
'hann'
,
n_mels
=
n_mels
,
htk
=
True
,
norm
=
None
)
win_length
=
n_fft
,
center
=
True
,
window
=
'hann'
,
n_mels
=
n_mels
,
htk
=
True
,
norm
=
None
)
# Note: Using relaxed rtol instead of atol
# Note: Using relaxed rtol instead of atol
torch
.
testing
.
assert_allclose
(
melspec_ta
,
torch
.
from_numpy
(
melspec_lr
[
None
,
...]),
atol
=
1e-8
,
rtol
=
1e-3
)
self
.
assertEqual
(
melspec_ta
,
torch
.
from_numpy
(
melspec_lr
[
None
,
...]),
atol
=
1e-8
,
rtol
=
1e-3
)
def
test_InverseMelScale
(
self
):
def
test_InverseMelScale
(
self
):
"""InverseMelScale transform is comparable to that of librosa"""
"""InverseMelScale transform is comparable to that of librosa"""
...
@@ -338,7 +332,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
...
@@ -338,7 +332,7 @@ class TestTransforms(_LibrosaMixin, unittest.TestCase):
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/366 for the discussion of the choice of algorithm
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# https://github.com/pytorch/audio/pull/448/files#r385747021 for the distribution of P-inf
# distance over frequencies.
# distance over frequencies.
torch
.
testing
.
assert_allclose
(
spec_ta
,
spec_lr
,
atol
=
threshold
,
rtol
=
1e-5
)
self
.
assertEqual
(
spec_ta
,
spec_lr
,
atol
=
threshold
,
rtol
=
1e-5
)
threshold
=
1700.0
threshold
=
1700.0
# This threshold was choosen empirically, based on the following observations
# This threshold was choosen empirically, based on the following observations
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment