Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
bc6b7f97
"src/vscode:/vscode.git/clone" did not exist on "00f95b9755718aabb65456e791b8408526ae6e76"
Unverified
Commit
bc6b7f97
authored
May 20, 2020
by
moto
Committed by
GitHub
May 20, 2020
Browse files
Adopt PyTorch's test util to transforms test (#652)
parent
ac7c052f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
7 deletions
+8
-7
test/test_transforms.py
test/test_transforms.py
+8
-7
No files found.
test/test_transforms.py
View file @
bc6b7f97
...
@@ -2,6 +2,7 @@ import math
...
@@ -2,6 +2,7 @@ import math
import
unittest
import
unittest
import
torch
import
torch
from
torch.testing._internal.common_utils
import
TestCase
import
torchaudio
import
torchaudio
import
torchaudio.transforms
as
transforms
import
torchaudio.transforms
as
transforms
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
...
@@ -9,7 +10,7 @@ import torchaudio.functional as F
...
@@ -9,7 +10,7 @@ import torchaudio.functional as F
import
common_utils
import
common_utils
class
Tester
(
unittest
.
TestCase
):
class
Tester
(
TestCase
):
# create a sinewave signal for testing
# create a sinewave signal for testing
sample_rate
=
16000
sample_rate
=
16000
...
@@ -49,7 +50,7 @@ class Tester(unittest.TestCase):
...
@@ -49,7 +50,7 @@ class Tester(unittest.TestCase):
mag_to_db_torch
=
mag_to_db_transform
(
torch
.
abs
(
waveform
))
mag_to_db_torch
=
mag_to_db_transform
(
torch
.
abs
(
waveform
))
power_to_db_torch
=
power_to_db_transform
(
torch
.
pow
(
waveform
,
2
))
power_to_db_torch
=
power_to_db_transform
(
torch
.
pow
(
waveform
,
2
))
torch
.
testing
.
assert_allclose
(
mag_to_db_torch
,
power_to_db_torch
)
self
.
assertEqual
(
mag_to_db_torch
,
power_to_db_torch
)
def
test_melscale_load_save
(
self
):
def
test_melscale_load_save
(
self
):
specgram
=
torch
.
ones
(
1
,
1000
,
100
)
specgram
=
torch
.
ones
(
1
,
1000
,
100
)
...
@@ -63,7 +64,7 @@ class Tester(unittest.TestCase):
...
@@ -63,7 +64,7 @@ class Tester(unittest.TestCase):
fb_copy
=
melscale_transform_copy
.
fb
fb_copy
=
melscale_transform_copy
.
fb
self
.
assertEqual
(
fb_copy
.
size
(),
(
1000
,
128
))
self
.
assertEqual
(
fb_copy
.
size
(),
(
1000
,
128
))
torch
.
testing
.
assert_allclose
(
fb
,
fb_copy
)
self
.
assertEqual
(
fb
,
fb_copy
)
def
test_melspectrogram_load_save
(
self
):
def
test_melspectrogram_load_save
(
self
):
waveform
=
self
.
waveform
.
float
()
waveform
=
self
.
waveform
.
float
()
...
@@ -79,10 +80,10 @@ class Tester(unittest.TestCase):
...
@@ -79,10 +80,10 @@ class Tester(unittest.TestCase):
fb
=
mel_spectrogram_transform
.
mel_scale
.
fb
fb
=
mel_spectrogram_transform
.
mel_scale
.
fb
fb_copy
=
mel_spectrogram_transform_copy
.
mel_scale
.
fb
fb_copy
=
mel_spectrogram_transform_copy
.
mel_scale
.
fb
torch
.
testing
.
assert_allclose
(
window
,
window_copy
)
self
.
assertEqual
(
window
,
window_copy
)
# the default for n_fft = 400 and n_mels = 128
# the default for n_fft = 400 and n_mels = 128
self
.
assertEqual
(
fb_copy
.
size
(),
(
201
,
128
))
self
.
assertEqual
(
fb_copy
.
size
(),
(
201
,
128
))
torch
.
testing
.
assert_allclose
(
fb
,
fb_copy
)
self
.
assertEqual
(
fb
,
fb_copy
)
def
test_mel2
(
self
):
def
test_mel2
(
self
):
top_db
=
80.
top_db
=
80.
...
@@ -205,7 +206,7 @@ class Tester(unittest.TestCase):
...
@@ -205,7 +206,7 @@ class Tester(unittest.TestCase):
computed_transform
=
transform
(
specgram
)
computed_transform
=
transform
(
specgram
)
computed_functional
=
F
.
compute_deltas
(
specgram
,
win_length
=
win_length
)
computed_functional
=
F
.
compute_deltas
(
specgram
,
win_length
=
win_length
)
torch
.
testing
.
assert_allclose
(
computed_functional
,
computed_transform
,
atol
=
atol
,
rtol
=
rtol
)
self
.
assertEqual
(
computed_functional
,
computed_transform
,
atol
=
atol
,
rtol
=
rtol
)
def
test_compute_deltas_twochannel
(
self
):
def
test_compute_deltas_twochannel
(
self
):
specgram
=
torch
.
tensor
([
1.
,
2.
,
3.
,
4.
]).
repeat
(
1
,
2
,
1
)
specgram
=
torch
.
tensor
([
1.
,
2.
,
3.
,
4.
]).
repeat
(
1
,
2
,
1
)
...
@@ -214,7 +215,7 @@ class Tester(unittest.TestCase):
...
@@ -214,7 +215,7 @@ class Tester(unittest.TestCase):
transform
=
transforms
.
ComputeDeltas
(
win_length
=
3
)
transform
=
transforms
.
ComputeDeltas
(
win_length
=
3
)
computed
=
transform
(
specgram
)
computed
=
transform
(
specgram
)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
assert
computed
.
shape
==
expected
.
shape
,
(
computed
.
shape
,
expected
.
shape
)
torch
.
testing
.
assert_allclose
(
computed
,
expected
,
atol
=
1e-6
,
rtol
=
1e-8
)
self
.
assertEqual
(
computed
,
expected
,
atol
=
1e-6
,
rtol
=
1e-8
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment