Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
cba11009
Commit
cba11009
authored
Mar 12, 2018
by
Peter Goldsborough
Committed by
Soumith Chintala
Apr 25, 2018
Browse files
Added test to make sure loading and saving gives the same file
parent
8a41ecdc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
10 deletions
+20
-10
test/test.py
test/test.py
+20
-10
No files found.
test/test.py
View file @
cba11009
...
@@ -7,8 +7,8 @@ import os
...
@@ -7,8 +7,8 @@ import os
class
Test_LoadSave
(
unittest
.
TestCase
):
class
Test_LoadSave
(
unittest
.
TestCase
):
test_dirpath
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
test_dirpath
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
test_filepath
=
os
.
path
.
join
(
test_filepath
=
os
.
path
.
join
(
test_dirpath
,
"assets"
,
test_dirpath
,
"assets"
,
"steam-train-whistle-daniel_simon.mp3"
)
"steam-train-whistle-daniel_simon.mp3"
)
def
test_load
(
self
):
def
test_load
(
self
):
# check normal loading
# check normal loading
...
@@ -16,7 +16,6 @@ class Test_LoadSave(unittest.TestCase):
...
@@ -16,7 +16,6 @@ class Test_LoadSave(unittest.TestCase):
self
.
assertEqual
(
sr
,
44100
)
self
.
assertEqual
(
sr
,
44100
)
self
.
assertEqual
(
x
.
size
(),
(
278756
,
2
))
self
.
assertEqual
(
x
.
size
(),
(
278756
,
2
))
self
.
assertGreater
(
x
.
sum
(),
0
)
self
.
assertGreater
(
x
.
sum
(),
0
)
print
# check normalizing
# check normalizing
x
,
sr
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
True
)
x
,
sr
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
True
)
...
@@ -28,8 +27,8 @@ class Test_LoadSave(unittest.TestCase):
...
@@ -28,8 +27,8 @@ class Test_LoadSave(unittest.TestCase):
torchaudio
.
load
(
"file-does-not-exist.mp3"
)
torchaudio
.
load
(
"file-does-not-exist.mp3"
)
with
self
.
assertRaises
(
OSError
):
with
self
.
assertRaises
(
OSError
):
tdir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
tdir
=
os
.
path
.
join
(
self
.
test_dirpath
),
"torchaudio"
)
os
.
path
.
dirname
(
self
.
test_dirpath
),
"torchaudio"
)
torchaudio
.
load
(
tdir
)
torchaudio
.
load
(
tdir
)
def
test_save
(
self
):
def
test_save
(
self
):
...
@@ -80,24 +79,35 @@ class Test_LoadSave(unittest.TestCase):
...
@@ -80,24 +79,35 @@ class Test_LoadSave(unittest.TestCase):
# don't save to folders that don't exist
# don't save to folders that don't exist
with
self
.
assertRaises
(
OSError
):
with
self
.
assertRaises
(
OSError
):
new_filepath
=
os
.
path
.
join
(
new_filepath
=
os
.
path
.
join
(
self
.
test_dirpath
,
"no-path"
,
self
.
test_dirpath
,
"no-path"
,
"test.wav"
)
"test.wav"
)
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
# save created file
# save created file
sinewave_filepath
=
os
.
path
.
join
(
sinewave_filepath
=
os
.
path
.
join
(
self
.
test_dirpath
,
"assets"
,
self
.
test_dirpath
,
"assets"
,
"sinewave.wav"
)
"sinewave.wav"
)
sr
=
16000
sr
=
16000
freq
=
440
freq
=
440
volume
=
0.3
volume
=
0.3
y
=
(
torch
.
cos
(
2
*
math
.
pi
*
torch
.
arange
(
0
,
4
*
sr
)
*
freq
/
sr
)).
float
()
y
=
(
torch
.
cos
(
2
*
math
.
pi
*
torch
.
arange
(
0
,
4
*
sr
)
*
freq
/
sr
)).
float
()
y
.
unsqueeze_
(
1
)
y
.
unsqueeze_
(
1
)
# y is between -1 and 1, so must scale
# y is between -1 and 1, so must scale
y
=
(
y
*
volume
*
2
**
31
).
long
()
y
=
(
y
*
volume
*
2
**
31
).
long
()
torchaudio
.
save
(
sinewave_filepath
,
y
,
sr
)
torchaudio
.
save
(
sinewave_filepath
,
y
,
sr
)
self
.
assertTrue
(
os
.
path
.
isfile
(
sinewave_filepath
))
self
.
assertTrue
(
os
.
path
.
isfile
(
sinewave_filepath
))
def
test_load_and_save_is_identity
(
self
):
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
tensor
,
sample_rate
=
torchaudio
.
load
(
input_path
)
output_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'test.wav'
)
torchaudio
.
save
(
output_path
,
tensor
,
sample_rate
)
tensor2
,
sample_rate2
=
torchaudio
.
load
(
output_path
)
self
.
assertTrue
(
tensor
.
allclose
(
tensor2
))
self
.
assertEqual
(
sample_rate
,
sample_rate2
)
os
.
unlink
(
output_path
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment