Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b40aee5a
Unverified
Commit
b40aee5a
authored
Sep 17, 2021
by
nateanl
Committed by
GitHub
Sep 17, 2021
Browse files
Refactor batch consistency test in transforms (#1772)
parent
0f822179
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
132 deletions
+87
-132
test/torchaudio_unittest/transforms/batch_consistency_test.py
.../torchaudio_unittest/transforms/batch_consistency_test.py
+87
-132
No files found.
test/torchaudio_unittest/transforms/batch_consistency_test.py
View file @
b40aee5a
"""Test numerical consistency among single input and batched input."""
"""Test numerical consistency among single input and batched input."""
import
torch
import
torch
import
torchaudio
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio
import
transforms
as
T
from
torchaudio_unittest
import
common_utils
from
torchaudio_unittest
import
common_utils
class
TestTransforms
(
common_utils
.
TorchaudioTestCase
):
class
TestTransforms
(
common_utils
.
TorchaudioTestCase
):
"""Test suite for classes defined in `transforms` module"""
backend
=
'default'
backend
=
'default'
"""Test suite for classes defined in `transforms` module"""
def
assert_batch_consistency
(
def
test_batch_AmplitudeToDB
(
self
):
self
,
transform
,
batch
,
*
args
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
spec
=
torch
.
rand
((
2
,
6
,
201
))
**
kwargs
):
n
=
batch
.
size
(
0
)
# Single then transform then batch
# Compute items separately, then batch the result
expected
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
).
repeat
(
3
,
1
,
1
)
torch
.
random
.
manual_seed
(
seed
)
items_input
=
batch
.
clone
()
items_result
=
torch
.
stack
([
transform
(
items_input
[
i
],
*
args
,
**
kwargs
)
for
i
in
range
(
n
)
])
# Batch then transform
# Batch the input and run
computed
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
.
repeat
(
3
,
1
,
1
))
torch
.
random
.
manual_seed
(
seed
)
batch_input
=
batch
.
clone
()
batch_result
=
transform
(
batch_input
,
*
args
,
**
kwargs
)
self
.
assertEqual
(
computed
,
expected
)
self
.
assertEqual
(
items_input
,
batch_input
,
rtol
=
rtol
,
atol
=
atol
)
self
.
assertEqual
(
items_result
,
batch_result
,
rtol
=
rtol
,
atol
=
atol
)
def
test_batch_Resample
(
self
):
def
test_batch_AmplitudeToDB
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
spec
=
torch
.
rand
((
3
,
2
,
6
,
201
))
transform
=
T
.
AmplitudeToDB
()
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
spec
)
expected
=
torchaudio
.
transforms
.
Resample
()(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
def
test_batch_Resample
(
self
):
computed
=
torchaudio
.
transforms
.
Resample
()(
waveform
.
repeat
(
3
,
1
,
1
))
waveform
=
torch
.
randn
(
3
,
2
,
2786
)
transform
=
T
.
Resample
()
self
.
assert
Equal
(
computed
,
expected
)
self
.
assert
_batch_consistency
(
transform
,
waveform
)
def
test_batch_MelScale
(
self
):
def
test_batch_MelScale
(
self
):
specgram
=
torch
.
randn
(
2
,
201
,
256
)
specgram
=
torch
.
randn
(
3
,
2
,
201
,
256
)
transform
=
T
.
MelScale
()
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MelScale
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelScale
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 128, 256)
self
.
assert_batch_consistency
(
transform
,
specgram
)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_InverseMelScale
(
self
):
def
test_batch_InverseMelScale
(
self
):
n_mels
=
32
n_mels
=
32
n_stft
=
5
n_stft
=
5
mel_spec
=
torch
.
randn
(
2
,
n_mels
,
32
)
**
2
mel_spec
=
torch
.
randn
(
3
,
2
,
n_mels
,
32
)
**
2
transform
=
T
.
InverseMelScale
(
n_stft
,
n_mels
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, n_mels, 32)
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
# exactly same result. For this reason, tolerance is very relaxed here.
self
.
assert
Equal
(
computed
,
expected
,
atol
=
1.0
,
rtol
=
1e-5
)
self
.
assert
_batch_consistency
(
transform
,
mel_spec
,
atol
=
1.0
,
rtol
=
1e-5
)
def
test_batch_compute_deltas
(
self
):
def
test_batch_compute_deltas
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
specgram
=
torch
.
randn
(
3
,
2
,
31
,
2786
)
transform
=
T
.
ComputeDeltas
()
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
specgram
)
expected
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_mulaw
(
self
):
def
test_batch_mulaw
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
# Single then transform then batch
# Single then transform then batch
waveform_encoded
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform
)
expected
=
[
T
.
MuLawEncoding
()(
waveform
[
i
])
for
i
in
range
(
3
)]
expected
=
waveform_encoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
expected
=
torch
.
stack
(
expected
)
# Batch then transform
# Batch then transform
waveform_batched
=
waveform
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
computed
=
T
.
MuLawEncoding
()(
waveform
)
computed
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform_batched
)
# shape = (3, 2, 201, 1394)
# shape = (3, 2, 201, 1394)
self
.
assertEqual
(
computed
,
expected
)
self
.
assertEqual
(
computed
,
expected
)
# Single then transform then batch
# Single then transform then batch
waveform
_decoded
=
torchaudio
.
transforms
.
MuLawDecoding
()(
waveform_encoded
)
expected
_decoded
=
[
T
.
MuLawDecoding
()(
expected
[
i
])
for
i
in
range
(
3
)]
expected
=
waveform_decoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
expected
_decoded
=
torch
.
stack
(
expected_decoded
)
# Batch then transform
# Batch then transform
computed
=
torchaudio
.
transforms
.
MuLawDecoding
()(
computed
)
computed
_decoded
=
T
.
MuLawDecoding
()(
computed
)
# shape = (3, 2, 201, 1394)
# shape = (3, 2, 201, 1394)
self
.
assertEqual
(
computed
,
expect
ed
)
self
.
assertEqual
(
computed
_decoded
,
expected_decod
ed
)
def
test_batch_spectrogram
(
self
):
def
test_batch_spectrogram
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
# Single then transform then batch
transform
=
T
.
Spectrogram
()
expected
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
self
.
assert_batch_consistency
(
transform
,
waveform
)
computed
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_inverse_spectrogram
(
self
):
def
test_batch_inverse_spectrogram
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
transform
=
torchaudio
.
transforms
.
Spectrogram
(
power
=
None
)(
waveform
)
specgram
=
common_utils
.
get_spectrogram
(
waveform
,
n_fft
=
400
)
specgram
=
specgram
.
reshape
(
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
])
# Single then transform then batch
transform
=
T
.
InverseSpectrogram
(
n_fft
=
400
)
expected
=
torchaudio
.
transforms
.
InverseSpectrogram
()(
transform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
self
.
assert_batch_consistency
(
transform
,
specgram
)
computed
=
torchaudio
.
transforms
.
InverseSpectrogram
()(
transform
.
repeat
(
3
,
1
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_melspectrogram
(
self
):
def
test_batch_melspectrogram
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
MelSpectrogram
()
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
waveform
)
expected
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_mfcc
(
self
):
def
test_batch_mfcc
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
MFCC
()
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
waveform
,
atol
=
1e-4
,
rtol
=
1e-5
)
expected
=
torchaudio
.
transforms
.
MFCC
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-4
,
rtol
=
1e-5
)
def
test_batch_lfcc
(
self
):
def
test_batch_lfcc
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
# Single then transform then batch
transform
=
T
.
LFCC
()
expected
=
torchaudio
.
transforms
.
LFCC
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
self
.
assert_batch_consistency
(
transform
,
waveform
,
atol
=
1e-4
,
rtol
=
1e-5
)
computed
=
torchaudio
.
transforms
.
LFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-4
,
rtol
=
1e-5
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_batch_TimeStretch
(
self
,
test_pseudo_complex
):
def
test_batch_TimeStretch
(
self
,
test_pseudo_complex
):
rate
=
2
rate
=
2
num_freq
=
1025
num_freq
=
1025
num_frames
=
400
num_frames
=
400
batch
=
3
spec
=
torch
.
randn
(
num_freq
,
num_frames
,
dtype
=
torch
.
complex64
)
spec
=
torch
.
randn
(
batch
,
num_freq
,
num_frames
,
dtype
=
torch
.
complex64
)
pattern
=
[
3
,
1
,
1
,
1
]
if
test_pseudo_complex
:
if
test_pseudo_complex
:
spec
=
torch
.
view_as_real
(
spec
)
spec
=
torch
.
view_as_real
(
spec
)
pattern
+=
[
1
]
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
num_freq
,
hop_length
=
512
,
)(
spec
).
repeat
(
*
pattern
)
# Batch then transform
transform
=
T
.
TimeStretch
(
computed
=
torchaudio
.
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
fixed_rate
=
rate
,
n_freq
=
num_freq
,
n_freq
=
num_freq
,
hop_length
=
512
,
hop_length
=
512
)
(
spec
.
repeat
(
*
pattern
))
)
self
.
assert
Equal
(
computed
,
expected
,
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assert
_batch_consistency
(
transform
,
spec
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_batch_Fade
(
self
):
def
test_batch_Fade
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
fade_in_len
=
3000
fade_in_len
=
3000
fade_out_len
=
3000
fade_out_len
=
3000
transform
=
T
.
Fade
(
fade_in_len
,
fade_out_len
)
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
waveform
)
expected
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_Vol
(
self
):
def
test_batch_Vol
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
Vol
(
gain
=
1.1
)
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
waveform
)
expected
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_spectral_centroid
(
self
):
def
test_batch_spectral_centroid
(
self
):
sample_rate
=
44100
sample_rate
=
44100
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
SpectralCentroid
(
sample_rate
)
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
waveform
)
expected
=
torchaudio
.
transforms
.
SpectralCentroid
(
sample_rate
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
SpectralCentroid
(
sample_rate
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
def
test_batch_pitch_shift
(
self
):
def
test_batch_pitch_shift
(
self
):
sample_rate
=
8000
sample_rate
=
8000
n_steps
=
-
2
n_steps
=
-
2
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)
# Single then transform then batch
self
.
assert_batch_consistency
(
transform
,
waveform
)
expected
=
torchaudio
.
transforms
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment