Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
740c5a86
Commit
740c5a86
authored
Aug 19, 2017
by
David Pollack
Browse files
update save/load
.gitignore _ext/
parent
ecb538df
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
113 additions
and
26 deletions
+113
-26
.gitignore
.gitignore
+2
-1
test/test.py
test/test.py
+77
-16
torchaudio/__init__.py
torchaudio/__init__.py
+34
-9
No files found.
.gitignore
View file @
740c5a86
...
@@ -3,8 +3,9 @@ __pycache__/
...
@@ -3,8 +3,9 @@ __pycache__/
*.py[cod]
*.py[cod]
*$py.class
*$py.class
# C extensions
# C extensions
/ folders
*.so
*.so
_ext/
# Distribution / packaging
# Distribution / packaging
.Python
.Python
...
...
test/test.py
View file @
740c5a86
import
unittest
import
torch
import
torch.nn
as
nn
import
torchaudio
import
torchaudio
import
os
class
Test_LoadSave
(
unittest
.
TestCase
):
test_dirpath
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
test_filepath
=
os
.
path
.
join
(
test_dirpath
,
"steam-train-whistle-daniel_simon.mp3"
)
def
test_load
(
self
):
# check normal loading
x
,
sr
=
torchaudio
.
load
(
self
.
test_filepath
)
self
.
assertEqual
(
sr
,
44100
)
self
.
assertEqual
(
x
.
size
(),
(
278756
,
2
))
# check normalizing
x
,
sr
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
True
)
self
.
assertTrue
(
x
.
min
()
>=
-
1.0
)
self
.
assertTrue
(
x
.
max
()
<=
1.0
)
# check raising errors
with
self
.
assertRaises
(
OSError
):
torchaudio
.
load
(
"file-does-not-exist.mp3"
)
with
self
.
assertRaises
(
OSError
):
tdir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
self
.
test_dirpath
),
"torchaudio"
)
torchaudio
.
load
(
tdir
)
def
test_save
(
self
):
# load signal
x
,
sr
=
torchaudio
.
load
(
self
.
test_filepath
)
# check save
new_filepath
=
os
.
path
.
join
(
self
.
test_dirpath
,
"test.wav"
)
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
self
.
assertTrue
(
os
.
path
.
isfile
(
new_filepath
))
os
.
unlink
(
new_filepath
)
# check automatic normalization
x
/=
1
<<
31
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
self
.
assertTrue
(
os
.
path
.
isfile
(
new_filepath
))
os
.
unlink
(
new_filepath
)
# test save 1d tensor
x
=
x
[:,
0
]
# get mono signal
x
.
squeeze_
()
# remove channel dim
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
self
.
assertTrue
(
os
.
path
.
isfile
(
new_filepath
))
os
.
unlink
(
new_filepath
)
# don't allow invalid sizes as inputs
with
self
.
assertRaises
(
ValueError
):
x
.
unsqueeze_
(
0
)
# N x L not L x N
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
with
self
.
assertRaises
(
ValueError
):
x
.
squeeze_
()
x
.
unsqueeze_
(
1
)
x
.
unsqueeze_
(
0
)
# 1 x L x 1
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
# automatically convert sr from floating point to int
x
.
squeeze_
(
0
)
torchaudio
.
save
(
new_filepath
,
x
,
float
(
sr
))
self
.
assertTrue
(
os
.
path
.
isfile
(
new_filepath
))
os
.
unlink
(
new_filepath
)
# don't allow uneven integers
with
self
.
assertRaises
(
TypeError
):
torchaudio
.
save
(
new_filepath
,
x
,
float
(
sr
)
+
0.5
)
self
.
assertTrue
(
os
.
path
.
isfile
(
new_filepath
))
os
.
unlink
(
new_filepath
)
# don't save to folders that don't exist
with
self
.
assertRaises
(
OSError
):
new_filepath
=
os
.
path
.
join
(
self
.
test_dirpath
,
"no-path"
,
"test.wav"
)
torchaudio
.
save
(
new_filepath
,
x
,
sr
)
x
,
sample_rate
=
torchaudio
.
load
(
"steam-train-whistle-daniel_simon.mp3"
)
if
__name__
==
'__main__'
:
print
(
sample_rate
)
unittest
.
main
()
print
(
x
.
size
())
print
(
x
[
10000
])
print
(
x
.
min
(),
x
.
max
())
print
(
x
.
mean
(),
x
.
std
())
x
,
sample_rate
=
torchaudio
.
load
(
"steam-train-whistle-daniel_simon.mp3"
,
out
=
torch
.
LongTensor
())
print
(
sample_rate
)
print
(
x
.
size
())
print
(
x
[
10000
])
print
(
x
.
min
(),
x
.
max
())
torchaudio/__init__.py
View file @
740c5a86
...
@@ -14,27 +14,52 @@ def check_input(src):
...
@@ -14,27 +14,52 @@ def check_input(src):
if
not
src
.
__module__
==
'torch'
:
if
not
src
.
__module__
==
'torch'
:
raise
TypeError
(
'Expected a CPU based tensor, got %s'
%
type
(
src
))
raise
TypeError
(
'Expected a CPU based tensor, got %s'
%
type
(
src
))
def
load
(
filepath
,
out
=
None
,
normalization
=
None
):
def
load
(
filename
,
out
=
None
):
# check if valid file
if
not
os
.
path
.
isfile
(
filepath
):
raise
OSError
(
"{} not found or is a directory"
.
format
(
filepath
))
# initialize output tensor
if
out
is
not
None
:
if
out
is
not
None
:
check_input
(
out
)
check_input
(
out
)
else
:
else
:
out
=
torch
.
FloatTensor
()
out
=
torch
.
FloatTensor
()
# load audio signal
typename
=
type
(
out
).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
out
).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
th_sox
,
'libthsox_{}_read_audio_file'
.
format
(
typename
))
func
=
getattr
(
th_sox
,
'libthsox_{}_read_audio_file'
.
format
(
typename
))
sample_rate_p
=
ffi
.
new
(
'int*'
)
sample_rate_p
=
ffi
.
new
(
'int*'
)
func
(
str
(
file
name
).
encode
(
"
ascii
"
),
out
,
sample_rate_p
)
func
(
str
(
file
path
).
encode
(
"
utf-8
"
),
out
,
sample_rate_p
)
sample_rate
=
sample_rate_p
[
0
]
sample_rate
=
sample_rate_p
[
0
]
# normalize if needed
if
isinstance
(
normalization
,
bool
)
and
normalization
:
out
/=
1
<<
31
# assuming 16-bit depth
elif
isinstance
(
normalization
,
(
float
,
int
)):
out
/=
normalization
# normalize with custom value
return
out
,
sample_rate
return
out
,
sample_rate
def
save
(
filepath
,
src
,
sample_rate
):
def
save
(
filepath
,
src
,
sample_rate
):
filename
,
extension
=
os
.
path
.
splitext
(
filepath
)
# check if save directory exists
if
type
(
sample_rate
)
!=
int
:
abs_dirpath
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
filepath
))
if
not
os
.
path
.
isdir
(
abs_dirpath
):
raise
OSError
(
"Directory does not exist: {}"
.
format
(
abs_dirpath
))
# Check/Fix shape of source data
if
len
(
src
.
size
())
==
1
:
# 1d tensors as assumed to be mono signals
src
.
unsqueeze_
(
1
)
elif
len
(
src
.
size
())
>
2
or
src
.
size
(
1
)
>
2
:
raise
ValueError
(
"Expected format (L x N), N = 1 or 2, but found {}"
.
format
(
src
.
size
()))
# check if sample_rate is an integer
if
not
isinstance
(
sample_rate
,
int
):
if
int
(
sample_rate
)
==
sample_rate
:
sample_rate
=
int
(
sample_rate
)
else
:
raise
TypeError
(
'Sample rate should be a integer'
)
raise
TypeError
(
'Sample rate should be a integer'
)
# programs such as librosa normalize the signal, unnormalize if detected
if
src
.
min
()
>=
-
1.0
and
src
.
max
()
<=
1.0
:
src
=
src
*
(
1
<<
31
)
# assuming 16-bit depth
src
=
src
.
long
()
# save data to file
filename
,
extension
=
os
.
path
.
splitext
(
filepath
)
check_input
(
src
)
check_input
(
src
)
typename
=
type
(
src
).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
src
).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
th_sox
,
'libthsox_{}_write_audio_file'
.
format
(
typename
))
func
=
getattr
(
th_sox
,
'libthsox_{}_write_audio_file'
.
format
(
typename
))
func
(
bytes
(
filepath
,
"utf-8"
),
src
,
bytes
(
extension
[
1
:],
"utf-8"
),
sample_rate
)
func
(
bytes
(
filepath
,
"ascii"
),
src
,
extension
[
1
:],
sample_rate
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment