OpenDAS / Torchaudio / Commits

Commit b29a4639: [BC] Standardization of Transforms/Functionals (#152)
Authored Jul 24, 2019 by jamarshon; committed by cpuhrsch, Jul 24, 2019
Parent: 2271a7ae
Showing 5 changed files with 365 additions and 839 deletions:

    test/test_functional.py      +2    -52
    test/test_jit.py             +12   -92
    test/test_transforms.py      +74   -142
    torchaudio/functional.py     +115  -240
    torchaudio/transforms.py     +162  -313
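The headline backwards-incompatible change visible in the diffs below is the tensor layout: waveforms move to channels-first `(channel, time)` instead of `(time, channel)`, keyword arguments are standardized (`sr` to `sample_rate`, `ws` to `win_length`, `hop` to `hop_length`, `window` to `window_fn`), and the tests for the old layout/scaling transforms (Scale, PadTrim's channels_first flag, DownmixMono, LC2CL, BLC2CBL, Compose) are removed. A minimal sketch of the layout shift, with the removed names kept only in comments; this is an illustration based on the test changes, not code from the commit itself:

import torch

# Pre-commit layout: (time, channel), with transforms.Scale() for scaling and
# transforms.LC2CL() for reordering; their tests are deleted below.
# Post-commit layout: channels-first, and scaling becomes a plain division
# (mirroring the new Tester.scale() helper added in test_transforms.py).
int_waveform = (torch.rand(1, 64000) * 2 ** 31).long()     # (channel, time)
waveform = int_waveform.to(torch.get_default_dtype()) / float(2 ** 31)
assert waveform.min() >= -1. and waveform.max() <= 1.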
test/test_functional.py (view file @ b29a4639)

@@ -2,6 +2,8 @@ import math
 import torch
 import torchaudio
+import torchaudio.functional as F
+import pytest
 import unittest
 import test.common_utils
@@ -11,10 +13,6 @@ if IMPORT_LIBROSA:
     import numpy as np
     import librosa
 
-import pytest
-import torchaudio.functional as F
-
-xfail = pytest.mark.xfail
 
 class TestFunctional(unittest.TestCase):
     data_sizes = [(2, 20), (3, 15), (4, 10)]
@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
     return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
 
 
-@pytest.mark.parametrize('fft_length', [512])
-@pytest.mark.parametrize('hop_length', [256])
-@pytest.mark.parametrize('waveform', [
-    (torch.randn(1, 100000)),
-    (torch.randn(1, 2, 100000)),
-    pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
-])
-@pytest.mark.parametrize('pad_mode', [
-    # 'constant',
-    'reflect',
-])
-@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
-def test_stft(waveform, fft_length, hop_length, pad_mode):
-    """
-    Test STFT for multi-channel signals.
-
-    Padding: Value in having padding outside of torch.stft?
-    """
-    pad = fft_length // 2
-    window = torch.hann_window(fft_length)
-    complex_spec = F.stft(waveform, fft_length=fft_length, hop_length=hop_length,
-                          window=window, pad_mode=pad_mode)
-    mag_spec, phase_spec = F.magphase(complex_spec)
-
-    # == Test shape
-    expected_size = list(waveform.size()[:-1])
-    expected_size += [fft_length // 2 + 1,
-                      _num_stft_bins(waveform.size(-1), fft_length, hop_length, pad), 2]
-
-    assert complex_spec.dim() == waveform.dim() + 2
-    assert complex_spec.size() == torch.Size(expected_size)
-
-    # == Test values
-    fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
-    # note that librosa *automatically* pad with fft_length // 2.
-    expected_complex_spec = np.apply_along_axis(librosa.stft, -1, waveform.numpy(), **fft_config)
-    expected_mag_spec, _ = librosa.magphase(expected_complex_spec)
-
-    # Convert torch to np.complex
-    complex_spec = complex_spec.numpy()
-    complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]
-
-    assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
-    assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
-
-
 @pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
 @pytest.mark.parametrize('complex_specgrams', [
     torch.randn(1, 2, 1025, 400, 2),
test/test_jit.py (view file @ b29a4639)

@@ -30,40 +30,18 @@ class Test_JIT(unittest.TestCase):
         self.assertTrue(torch.allclose(jit_out, py_out))
 
-    def test_torchscript_scale(self):
-        @torch.jit.script
-        def jit_method(tensor, factor):
-            # type: (Tensor, int) -> Tensor
-            return F.scale(tensor, factor)
-
-        tensor = torch.rand((10, 1))
-        factor = 2
-
-        jit_out = jit_method(tensor, factor)
-        py_out = F.scale(tensor, factor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_scale(self):
-        tensor = torch.rand((10, 1), device="cuda")
-        self._test_script_module(tensor, transforms.Scale)
-
     def test_torchscript_pad_trim(self):
         @torch.jit.script
-        def jit_method(tensor, ch_dim, max_len, len_dim, fill_value):
-            # type: (Tensor, int, int, int, float) -> Tensor
-            return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        def jit_method(tensor, max_len, fill_value):
+            # type: (Tensor, int, float) -> Tensor
+            return F.pad_trim(tensor, max_len, fill_value)
 
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
+        tensor = torch.rand((1, 10))
         max_len = 5
-        len_dim = 0
         fill_value = 3.
 
-        jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value)
-        py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
+        jit_out = jit_method(tensor, max_len, fill_value)
+        py_out = F.pad_trim(tensor, max_len, fill_value)
 
         self.assertTrue(torch.allclose(jit_out, py_out))
@@ -74,45 +52,6 @@ class Test_JIT(unittest.TestCase):
         self._test_script_module(tensor, transforms.PadTrim, max_len)
 
-    def test_torchscript_downmix_mono(self):
-        @torch.jit.script
-        def jit_method(tensor, ch_dim):
-            # type: (Tensor, int) -> Tensor
-            return F.downmix_mono(tensor, ch_dim)
-
-        tensor = torch.rand((10, 1))
-        ch_dim = 1
-
-        jit_out = jit_method(tensor, ch_dim)
-        py_out = F.downmix_mono(tensor, ch_dim)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_downmix_mono(self):
-        tensor = torch.rand((1, 10), device="cuda")
-        self._test_script_module(tensor, transforms.DownmixMono)
-
-    def test_torchscript_LC2CL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.LC2CL(tensor)
-
-        tensor = torch.rand((10, 1))
-
-        jit_out = jit_method(tensor)
-        py_out = F.LC2CL(tensor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_LC2CL(self):
-        tensor = torch.rand((10, 1), device="cuda")
-        self._test_script_module(tensor, transforms.LC2CL)
-
     def test_torchscript_spectrogram(self):
         @torch.jit.script
         def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
@@ -167,7 +106,7 @@ class Test_JIT(unittest.TestCase):
             # type: (Tensor, float, float, float, Optional[float]) -> Tensor
             return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)
 
-        spec = torch.rand((10, 1))
+        spec = torch.rand((6, 201))
         multiplier = 10.
         amin = 1e-10
         db_multiplier = 0.
@@ -180,7 +119,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_SpectrogramToDB(self):
-        spec = torch.rand((10, 1), device="cuda")
+        spec = torch.rand((6, 201), device="cuda")
         self._test_script_module(spec, transforms.SpectrogramToDB)
@@ -211,32 +150,13 @@ class Test_JIT(unittest.TestCase):
         self._test_script_module(tensor, transforms.MelSpectrogram)
 
-    def test_torchscript_BLC2CBL(self):
-        @torch.jit.script
-        def jit_method(tensor):
-            # type: (Tensor) -> Tensor
-            return F.BLC2CBL(tensor)
-
-        tensor = torch.rand((10, 1000, 1))
-
-        jit_out = jit_method(tensor)
-        py_out = F.BLC2CBL(tensor)
-
-        self.assertTrue(torch.allclose(jit_out, py_out))
-
-    @unittest.skipIf(not RUN_CUDA, "no CUDA")
-    def test_scriptmodule_BLC2CBL(self):
-        tensor = torch.rand((10, 1000, 1), device="cuda")
-        self._test_script_module(tensor, transforms.BLC2CBL)
-
     def test_torchscript_mu_law_encoding(self):
         @torch.jit.script
         def jit_method(tensor, qc):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_encoding(tensor, qc)
 
-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256
 
         jit_out = jit_method(tensor, qc)
@@ -246,7 +166,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawEncoding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")
         self._test_script_module(tensor, transforms.MuLawEncoding)
@@ -256,7 +176,7 @@ class Test_JIT(unittest.TestCase):
             # type: (Tensor, int) -> Tensor
             return F.mu_law_expanding(tensor, qc)
 
-        tensor = torch.rand((10, 1))
+        tensor = torch.rand((1, 10))
         qc = 256
 
         jit_out = jit_method(tensor, qc)
@@ -266,7 +186,7 @@ class Test_JIT(unittest.TestCase):
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_MuLawExpanding(self):
-        tensor = torch.rand((10, 1), device="cuda")
+        tensor = torch.rand((1, 10), device="cuda")
         self._test_script_module(tensor, transforms.MuLawExpanding)
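Distilling the test_jit.py changes above: the functionals keep flat, TorchScript-compatible signatures after the standardization, but now take channels-first tensors and fewer positional arguments. A minimal sketch of the scripted pad_trim round-trip the test checks, assuming the post-commit torchaudio API at this revision:

import torch
import torchaudio.functional as F

@torch.jit.script
def jit_pad_trim(tensor, max_len, fill_value):
    # type: (Tensor, int, float) -> Tensor
    # Post-commit signature: the old ch_dim/len_dim arguments are gone.
    return F.pad_trim(tensor, max_len, fill_value)

waveform = torch.rand((1, 10))          # (channel, time)
jit_out = jit_pad_trim(waveform, 5, 3.)
py_out = F.pad_trim(waveform, 5, 3.)
assert torch.allclose(jit_out, py_out)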
test/test_transforms.py (view file @ b29a4639)

@@ -19,191 +19,123 @@ if IMPORT_SCIPY:
 class Tester(unittest.TestCase):
 
     # create a sinewave signal for testing
-    sr = 16000
+    sample_rate = 16000
     freq = 440
     volume = .3
-    sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
-    sig.unsqueeze_(1)  # (64000, 1)
-    sig = (sig * volume * 2 ** 31).long()
+    waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
+    waveform.unsqueeze_(0)  # (1, 64000)
+    waveform = (waveform * volume * 2 ** 31).long()
 
     # file for stereo stft test
     test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
-    test_filepath = os.path.join(test_dirpath, "assets",
-                                 "steam-train-whistle-daniel_simon.mp3")
+    test_filepath = os.path.join(test_dirpath, 'assets',
+                                 'steam-train-whistle-daniel_simon.mp3')
 
-    def test_scale(self):
-        audio_orig = self.sig.clone()
-        result = transforms.Scale()(audio_orig)
-        self.assertTrue(result.min() >= -1. and result.max() <= 1.)
-
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-        result = transforms.Scale(factor=maxminmax)(audio_orig)
-        self.assertTrue((result.min() == -1. or result.max() == 1.) and
-                        result.min() >= -1. and result.max() <= 1.)
-
-        repr_test = transforms.Scale()
-        self.assertTrue(repr_test.__repr__())
+    def scale(self, waveform, factor=float(2 ** 31)):
+        # scales a waveform by a factor
+        if not waveform.is_floating_point():
+            waveform = waveform.to(torch.get_default_dtype())
+        return waveform / factor
 
     def test_pad_trim(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
+        waveform = self.waveform.clone()
+        length_orig = waveform.size(1)
         length_new = int(length_orig * 1.2)
 
-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-
-        result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
+        result = transforms.PadTrim(max_len=length_new)(waveform)
         self.assertEqual(result.size(1), length_new)
 
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
         length_new = int(length_orig * 0.8)
-
-        result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
-        self.assertEqual(result.size(0), length_new)
-
-        repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-
-    def test_downmix_mono(self):
-        audio_L = self.sig.clone()
-        audio_R = self.sig.clone()
-        R_idx = int(audio_R.size(0) * 0.1)
-        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
-
-        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
-        self.assertTrue(audio_Stereo.size(1) == 2)
-
-        result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
-        self.assertTrue(result.size(1) == 1)
-
-        repr_test = transforms.DownmixMono(channels_first=False)
-        self.assertTrue(repr_test.__repr__())
-
-    def test_lc2cl(self):
-        audio = self.sig.clone()
-        result = transforms.LC2CL()(audio)
-        self.assertTrue(result.size()[::-1] == audio.size())
-
-        repr_test = transforms.LC2CL()
-        self.assertTrue(repr_test.__repr__())
-
-    def test_compose(self):
-        audio_orig = self.sig.clone()
-        length_orig = audio_orig.size(0)
-        length_new = int(length_orig * 1.2)
-        maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
-
-        tset = (transforms.Scale(factor=maxminmax),
-                transforms.PadTrim(max_len=length_new, channels_first=False))
-        result = transforms.Compose(tset)(audio_orig)
-
-        self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
-        self.assertTrue(result.size(0) == length_new)
-
-        repr_test = transforms.Compose(tset)
-        self.assertTrue(repr_test.__repr__())
+        result = transforms.PadTrim(max_len=length_new)(waveform)
+        self.assertEqual(result.size(1), length_new)
 
     def test_mu_law_companding(self):
 
         quantization_channels = 256
 
-        sig = self.sig.clone()
-        sig = sig / torch.abs(sig).max()
-        self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)
+        waveform = self.waveform.clone()
+        waveform /= torch.abs(waveform).max()
+        self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)
 
-        sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
-        self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels)
+        waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
+        self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
 
-        sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
-        self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
-
-        repr_test = transforms.MuLawEncoding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
-        repr_test = transforms.MuLawExpanding(quantization_channels)
-        self.assertTrue(repr_test.__repr__())
+        waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
+        self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
 
     def test_mel2(self):
         top_db = 80.
-        s2db = transforms.SpectrogramToDB("power", top_db)
+        s2db = transforms.SpectrogramToDB('power', top_db)
 
-        audio_orig = self.sig.clone()  # (16000, 1)
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        waveform = self.waveform.clone()  # (1, 16000)
+        waveform_scaled = self.scale(waveform)  # (1, 16000)
         mel_transform = transforms.MelSpectrogram()
         # check defaults
-        spectrogram_torch = s2db(mel_transform(audio_scaled))  # (1, 319, 40)
+        spectrogram_torch = s2db(mel_transform(waveform_scaled))  # (1, 128, 321)
         self.assertTrue(spectrogram_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
         # check correctness of filterbank conversion matrix
-        self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
         # check options
-        kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
+        kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500, 'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
         mel_transform2 = transforms.MelSpectrogram(**kwargs)
-        spectrogram2_torch = s2db(mel_transform2(audio_scaled))  # (1, 506, 50)
+        spectrogram2_torch = s2db(mel_transform2(waveform_scaled))  # (1, 50, 513)
         self.assertTrue(spectrogram2_torch.dim() == 3)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
-        self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
-        self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
+        self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
+        self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
         # check on multi-channel audio
-        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
-        spectrogram_stereo = s2db(mel_transform(x_stereo))
+        x_stereo, sr_stereo = torchaudio.load(self.test_filepath)  # (2, 278756), 44100
+        spectrogram_stereo = s2db(mel_transform(x_stereo))  # (2, 128, 1394)
         self.assertTrue(spectrogram_stereo.dim() == 3)
         self.assertTrue(spectrogram_stereo.size(0) == 2)
         self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
-        self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
+        self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
         # check filterbank matrix creation
-        fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
+        fb_matrix_transform = transforms.MelScale(n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
         self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
         self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
         self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
 
     def test_mfcc(self):
-        audio_orig = self.sig.clone()
-        audio_scaled = transforms.Scale()(audio_orig)  # (16000, 1)
-        audio_scaled = transforms.LC2CL()(audio_scaled)  # (1, 16000)
+        audio_orig = self.waveform.clone()
+        audio_scaled = self.scale(audio_orig)  # (1, 16000)
 
         sample_rate = 16000
         n_mfcc = 40
         n_mels = 128
-        mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
+        mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                     n_mfcc=n_mfcc,
                                                     norm='ortho')
 
         # check defaults
-        torch_mfcc = mfcc_transform(audio_scaled)
+        torch_mfcc = mfcc_transform(audio_scaled)  # (1, 40, 321)
         self.assertTrue(torch_mfcc.dim() == 3)
-        self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
-        self.assertTrue(torch_mfcc.shape[1] == 321)
+        self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
+        self.assertTrue(torch_mfcc.shape[2] == 321)
 
         # check melkwargs are passed through
-        melkwargs = {'ws': 200}
-        mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
+        melkwargs = {'win_length': 200}
+        mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                      n_mfcc=n_mfcc,
                                                      norm='ortho',
                                                      melkwargs=melkwargs)
-        torch_mfcc2 = mfcc_transform2(audio_scaled)
-        self.assertTrue(torch_mfcc2.shape[1] == 641)
+        torch_mfcc2 = mfcc_transform2(audio_scaled)  # (1, 40, 641)
+        self.assertTrue(torch_mfcc2.shape[2] == 641)
 
         # check norms work correctly
-        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
+        mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                               n_mfcc=n_mfcc,
                                                               norm=None)
-        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
+        torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)  # (1, 40, 321)
 
         norm_check = torch_mfcc.clone()
-        norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
-        norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
+        norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
+        norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
 
         self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
@@ -212,45 +144,45 @@ class Tester(unittest.TestCase):
         def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
             input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
             sound, sample_rate = torchaudio.load(input_path)
-            sound_librosa = sound.cpu().numpy().squeeze().T  # squeeze batch and channel first
+            sound_librosa = sound.cpu().numpy().squeeze()  # (64000)
 
             # test core spectrogram
-            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
+            spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
             out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
                                                                 n_fft=n_fft,
                                                                 hop_length=hop_length,
                                                                 power=2)
-            out_torch = spect_transform(sound).squeeze().cpu().t()
+            out_torch = spect_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
 
             # test mel spectrogram
-            melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
-                                                                      hop=hop_length, n_mels=n_mels, n_fft=n_fft)
+            melspect_transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, window_fn=torch.hann_window,
                                                                       hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
             librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
                                                          n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
                                                          htk=True, norm=None)
             librosa_mel_tensor = torch.from_numpy(librosa_mel)
-            torch_mel = melspect_transform(sound).squeeze().cpu().t()
+            torch_mel = melspect_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
 
             # test s2db
-            db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
-            db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
+            db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
+            db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
             db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
             self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
 
-            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
+            db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
             db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
             db_librosa_tensor = torch.from_numpy(db_librosa)
             self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))
 
             # test MFCC
-            melkwargs = {'hop': hop_length, 'n_fft': n_fft}
-            mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
+            melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
+            mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
                                                         n_mfcc=n_mfcc,
                                                         norm='ortho',
                                                         melkwargs=melkwargs)
@@ -271,7 +203,7 @@ class Tester(unittest.TestCase):
             librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
             librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
-            torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()
+            torch_mfcc = mfcc_transform(sound).squeeze().cpu()
             self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
@@ -308,27 +240,27 @@ class Tester(unittest.TestCase):
     def test_resample_size(self):
         input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
-        sound, sample_rate = torchaudio.load(input_path)
+        waveform, sample_rate = torchaudio.load(input_path)
 
         upsample_rate = sample_rate * 2
         downsample_rate = sample_rate // 2
         invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
-        self.assertRaises(ValueError, invalid_resample, sound)
+        self.assertRaises(ValueError, invalid_resample, waveform)
 
         upsample_resample = torchaudio.transforms.Resample(
             sample_rate, upsample_rate, resampling_method='sinc_interpolation')
-        up_sampled = upsample_resample(sound)
+        up_sampled = upsample_resample(waveform)
 
         # we expect the upsampled signal to have twice as many samples
-        self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2)
+        self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
 
         downsample_resample = torchaudio.transforms.Resample(
             sample_rate, downsample_rate, resampling_method='sinc_interpolation')
-        down_sampled = downsample_resample(sound)
+        down_sampled = downsample_resample(waveform)
 
         # we expect the downsampled signal to have half as many samples
-        self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2)
+        self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
 
 
 if __name__ == '__main__':
     unittest.main()
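The librosa-consistency helper above is also the quickest way to sanity-check the new channels-first outputs by hand: with a `(channel, time)` input, the squeezed spectrogram lines up with librosa's `(freq, time)` layout without the old `.t()` transpose. A condensed sketch of that check, assuming a librosa version contemporary with this commit (the test calls the private `librosa.core.spectrum._spectrogram` helper) and with the file path and FFT parameters purely illustrative:

import librosa
import torch
import torchaudio

waveform, sample_rate = torchaudio.load('assets/sinewave.wav')     # (1, time)
spect = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200, power=2)(waveform)

out_librosa, _ = librosa.core.spectrum._spectrogram(
    y=waveform.squeeze().numpy(), n_fft=400, hop_length=200, power=2)

# No transpose needed any more: both sides are (freq, time).
assert torch.allclose(spect.squeeze(),
                      torch.from_numpy(out_librosa).to(spect.dtype), atol=1e-5)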
torchaudio/functional.py (view file @ b29a4639)
This diff is collapsed on the page (+115, -240).
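Although the functional.py diff is collapsed here, the surviving functional signatures can be read off test_jit.py above. A usage sketch under that assumption (post-commit API at this revision; values illustrative):

import torch
import torchaudio.functional as F

waveform = torch.rand(1, 100)                    # (channel, time)

padded  = F.pad_trim(waveform, 150, 0.)          # (waveform, max_len, fill_value)
encoded = F.mu_law_encoding(waveform, 256)       # (waveform, quantization_channels)
decoded = F.mu_law_expanding(encoded, 256)       # expand back (approximate inverse)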
torchaudio/transforms.py (view file @ b29a4639)
This diff is collapsed on the page (+162, -313).
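The transforms.py diff is likewise collapsed; its net effect on the public API is what test_transforms.py above exercises. A pipeline sketch with the standardized keyword names from those tests (parameter values illustrative, not prescribed by the commit):

import torch
import torchaudio.transforms as transforms

waveform = torch.rand(1, 16000)                                    # (channel, time)

mel_specgram = transforms.MelSpectrogram(
    sample_rate=16000, n_fft=800, win_length=500, hop_length=125,
    n_mels=50, window_fn=torch.hamming_window)(waveform)           # (1, 50, time)
mel_db = transforms.SpectrogramToDB('power', 80.)(mel_specgram)

mfcc = transforms.MFCC(sample_rate=16000, n_mfcc=40, norm='ortho',
                       melkwargs={'win_length': 200})(waveform)    # (1, 40, time)

resampled = transforms.Resample(16000, 32000,
                                resampling_method='sinc_interpolation')(waveform)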