Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b40aee5a
Unverified
Commit
b40aee5a
authored
Sep 17, 2021
by
nateanl
Committed by
GitHub
Sep 17, 2021
Browse files
Refactor batch consistency test in transforms (#1772)
parent
0f822179
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
132 deletions
+87
-132
test/torchaudio_unittest/transforms/batch_consistency_test.py
.../torchaudio_unittest/transforms/batch_consistency_test.py
+87
-132
No files found.
test/torchaudio_unittest/transforms/batch_consistency_test.py
View file @
b40aee5a
"""Test numerical consistency among single input and batched input."""
import
torch
import
torchaudio
from
parameterized
import
parameterized
from
torchaudio
import
transforms
as
T
from
torchaudio_unittest
import
common_utils
class
TestTransforms
(
common_utils
.
TorchaudioTestCase
):
"""Test suite for classes defined in `transforms` module"""
backend
=
'default'
"""Test suite for classes defined in `transforms` module"""
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
2
,
6
,
201
))
def
assert_batch_consistency
(
self
,
transform
,
batch
,
*
args
,
atol
=
1e-8
,
rtol
=
1e-5
,
seed
=
42
,
**
kwargs
):
n
=
batch
.
size
(
0
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
).
repeat
(
3
,
1
,
1
)
# Compute items separately, then batch the result
torch
.
random
.
manual_seed
(
seed
)
items_input
=
batch
.
clone
()
items_result
=
torch
.
stack
([
transform
(
items_input
[
i
],
*
args
,
**
kwargs
)
for
i
in
range
(
n
)
])
# Batch then transform
computed
=
torchaudio
.
transforms
.
AmplitudeToDB
()(
spec
.
repeat
(
3
,
1
,
1
))
# Batch the input and run
torch
.
random
.
manual_seed
(
seed
)
batch_input
=
batch
.
clone
()
batch_result
=
transform
(
batch_input
,
*
args
,
**
kwargs
)
self
.
assertEqual
(
computed
,
expected
)
self
.
assertEqual
(
items_input
,
batch_input
,
rtol
=
rtol
,
atol
=
atol
)
self
.
assertEqual
(
items_result
,
batch_result
,
rtol
=
rtol
,
atol
=
atol
)
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
2
,
2786
)
def
test_batch_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
3
,
2
,
6
,
201
))
transform
=
T
.
AmplitudeToDB
()
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Resample
()(
waveform
).
repeat
(
3
,
1
,
1
)
self
.
assert_batch_consistency
(
transform
,
spec
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Resample
()(
waveform
.
repeat
(
3
,
1
,
1
))
def
test_batch_Resample
(
self
):
waveform
=
torch
.
randn
(
3
,
2
,
2786
)
transform
=
T
.
Resample
()
self
.
assert
Equal
(
computed
,
expected
)
self
.
assert
_batch_consistency
(
transform
,
waveform
)
def
test_batch_MelScale
(
self
):
specgram
=
torch
.
randn
(
2
,
201
,
256
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MelScale
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelScale
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
specgram
=
torch
.
randn
(
3
,
2
,
201
,
256
)
transform
=
T
.
MelScale
()
# shape = (3, 2, 128, 256)
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
specgram
)
def
test_batch_InverseMelScale
(
self
):
n_mels
=
32
n_stft
=
5
mel_spec
=
torch
.
randn
(
2
,
n_mels
,
32
)
**
2
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
InverseMelScale
(
n_stft
,
n_mels
)(
mel_spec
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, n_mels, 32)
mel_spec
=
torch
.
randn
(
3
,
2
,
n_mels
,
32
)
**
2
transform
=
T
.
InverseMelScale
(
n_stft
,
n_mels
)
# Because InverseMelScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
self
.
assert
Equal
(
computed
,
expected
,
atol
=
1.0
,
rtol
=
1e-5
)
self
.
assert
_batch_consistency
(
transform
,
mel_spec
,
atol
=
1.0
,
rtol
=
1e-5
)
def
test_batch_compute_deltas
(
self
):
specgram
=
torch
.
randn
(
2
,
31
,
2786
)
specgram
=
torch
.
randn
(
3
,
2
,
31
,
2786
)
transform
=
T
.
ComputeDeltas
()
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
ComputeDeltas
()(
specgram
.
repeat
(
3
,
1
,
1
,
1
))
# shape = (3, 2, 201, 1394)
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
specgram
)
def
test_batch_mulaw
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
# Single then transform then batch
waveform_encoded
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform
)
expected
=
waveform_encoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
expected
=
[
T
.
MuLawEncoding
()(
waveform
[
i
])
for
i
in
range
(
3
)]
expected
=
torch
.
stack
(
expected
)
# Batch then transform
waveform_batched
=
waveform
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
computed
=
torchaudio
.
transforms
.
MuLawEncoding
()(
waveform_batched
)
computed
=
T
.
MuLawEncoding
()(
waveform
)
# shape = (3, 2, 201, 1394)
self
.
assertEqual
(
computed
,
expected
)
# Single then transform then batch
waveform
_decoded
=
torchaudio
.
transforms
.
MuLawDecoding
()(
waveform_encoded
)
expected
=
waveform_decoded
.
unsqueeze
(
0
).
repeat
(
3
,
1
,
1
)
expected
_decoded
=
[
T
.
MuLawDecoding
()(
expected
[
i
])
for
i
in
range
(
3
)]
expected
_decoded
=
torch
.
stack
(
expected_decoded
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MuLawDecoding
()(
computed
)
computed
_decoded
=
T
.
MuLawDecoding
()(
computed
)
# shape = (3, 2, 201, 1394)
self
.
assertEqual
(
computed
,
expect
ed
)
self
.
assertEqual
(
computed
_decoded
,
expected_decod
ed
)
def
test_batch_spectrogram
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
Spectrogram
()
# Batch then transform
computed
=
torchaudio
.
transforms
.
Spectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
def
test_batch_inverse_spectrogram
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
transform
=
torchaudio
.
transforms
.
Spectrogram
(
power
=
None
)(
waveform
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
InverseSpectrogram
()(
transform
).
repeat
(
3
,
1
,
1
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
specgram
=
common_utils
.
get_spectrogram
(
waveform
,
n_fft
=
400
)
specgram
=
specgram
.
reshape
(
3
,
2
,
specgram
.
shape
[
-
2
],
specgram
.
shape
[
-
1
])
transform
=
T
.
InverseSpectrogram
(
n_fft
=
400
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
InverseSpectrogram
()(
transform
.
repeat
(
3
,
1
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
specgram
)
def
test_batch_melspectrogram
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
MelSpectrogram
()
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MelSpectrogram
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
def
test_batch_mfcc
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
MFCC
()
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
MFCC
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
MFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-4
,
rtol
=
1e-5
)
self
.
assert_batch_consistency
(
transform
,
waveform
,
atol
=
1e-4
,
rtol
=
1e-5
)
def
test_batch_lfcc
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
LFCC
()(
waveform
).
repeat
(
3
,
1
,
1
,
1
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
LFCC
()
# Batch then transform
computed
=
torchaudio
.
transforms
.
LFCC
()(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-4
,
rtol
=
1e-5
)
self
.
assert_batch_consistency
(
transform
,
waveform
,
atol
=
1e-4
,
rtol
=
1e-5
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_batch_TimeStretch
(
self
,
test_pseudo_complex
):
rate
=
2
num_freq
=
1025
num_frames
=
400
batch
=
3
spec
=
torch
.
randn
(
num_freq
,
num_frames
,
dtype
=
torch
.
complex64
)
pattern
=
[
3
,
1
,
1
,
1
]
spec
=
torch
.
randn
(
batch
,
num_freq
,
num_frames
,
dtype
=
torch
.
complex64
)
if
test_pseudo_complex
:
spec
=
torch
.
view_as_real
(
spec
)
pattern
+=
[
1
]
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
num_freq
,
hop_length
=
512
,
)(
spec
).
repeat
(
*
pattern
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
TimeStretch
(
transform
=
T
.
TimeStretch
(
fixed_rate
=
rate
,
n_freq
=
num_freq
,
hop_length
=
512
,
)
(
spec
.
repeat
(
*
pattern
))
hop_length
=
512
)
self
.
assert
Equal
(
computed
,
expected
,
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assert
_batch_consistency
(
transform
,
spec
,
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_batch_Fade
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
fade_in_len
=
3000
fade_out_len
=
3000
transform
=
T
.
Fade
(
fade_in_len
,
fade_out_len
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Fade
(
fade_in_len
,
fade_out_len
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
def
test_batch_Vol
(
self
):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
2
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
8000
,
duration
=
1
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
Vol
(
gain
=
1.1
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
Vol
(
gain
=
1.1
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
def
test_batch_spectral_centroid
(
self
):
sample_rate
=
44100
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
SpectralCentroid
(
sample_rate
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
SpectralCentroid
(
sample_rate
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
SpectralCentroid
(
sample_rate
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
def
test_batch_pitch_shift
(
self
):
sample_rate
=
8000
n_steps
=
-
2
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
,
n_channels
=
6
)
waveform
=
waveform
.
reshape
(
3
,
2
,
-
1
)
transform
=
T
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)
# Single then transform then batch
expected
=
torchaudio
.
transforms
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)(
waveform
).
repeat
(
3
,
1
,
1
)
# Batch then transform
computed
=
torchaudio
.
transforms
.
PitchShift
(
sample_rate
,
n_steps
,
n_fft
=
400
)(
waveform
.
repeat
(
3
,
1
,
1
))
self
.
assertEqual
(
computed
,
expected
)
self
.
assert_batch_consistency
(
transform
,
waveform
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment