Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
7314b36d
Commit
7314b36d
authored
Aug 21, 2018
by
David Pollack
Committed by
Soumith Chintala
Aug 20, 2018
Browse files
allow loading with offsets and number of samples and saving specified bit precisions (#59)
parent
0b93ff06
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
154 additions
and
13 deletions
+154
-13
test/test.py
test/test.py
+51
-0
torchaudio/__init__.py
torchaudio/__init__.py
+37
-4
torchaudio/torch_sox.cpp
torchaudio/torch_sox.cpp
+52
-7
torchaudio/torch_sox.h
torchaudio/torch_sox.h
+14
-2
No files found.
test/test.py
View file @
7314b36d
...
@@ -99,6 +99,15 @@ class Test_LoadSave(unittest.TestCase):
...
@@ -99,6 +99,15 @@ class Test_LoadSave(unittest.TestCase):
torchaudio
.
save
(
sinewave_filepath
,
y
,
sr
)
torchaudio
.
save
(
sinewave_filepath
,
y
,
sr
)
self
.
assertTrue
(
os
.
path
.
isfile
(
sinewave_filepath
))
self
.
assertTrue
(
os
.
path
.
isfile
(
sinewave_filepath
))
# test precision
new_filepath
=
os
.
path
.
join
(
self
.
test_dirpath
,
"test.wav"
)
_
,
_
,
_
,
bp
=
torchaudio
.
info
(
sinewave_filepath
)
torchaudio
.
save
(
new_filepath
,
y
,
sr
,
precision
=
16
)
_
,
_
,
_
,
bp16
=
torchaudio
.
info
(
new_filepath
)
self
.
assertEqual
(
bp
,
32
)
self
.
assertEqual
(
bp16
,
16
)
os
.
unlink
(
new_filepath
)
def
test_load_and_save_is_identity
(
self
):
def
test_load_and_save_is_identity
(
self
):
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
tensor
,
sample_rate
=
torchaudio
.
load
(
input_path
)
tensor
,
sample_rate
=
torchaudio
.
load
(
input_path
)
...
@@ -109,6 +118,48 @@ class Test_LoadSave(unittest.TestCase):
...
@@ -109,6 +118,48 @@ class Test_LoadSave(unittest.TestCase):
self
.
assertEqual
(
sample_rate
,
sample_rate2
)
self
.
assertEqual
(
sample_rate
,
sample_rate2
)
os
.
unlink
(
output_path
)
os
.
unlink
(
output_path
)
def
test_load_partial
(
self
):
num_frames
=
100
offset
=
200
# load entire mono sinewave wav file, load a partial copy and then compare
input_sine_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
x_sine_full
,
sr_sine
=
torchaudio
.
load
(
input_sine_path
)
x_sine_part
,
_
=
torchaudio
.
load
(
input_sine_path
,
num_frames
=
num_frames
,
offset
=
offset
)
l1_error
=
x_sine_full
[
offset
:(
num_frames
+
offset
)].
sub
(
x_sine_part
).
abs
().
sum
().
item
()
# test for the correct number of samples and that the correct portion was loaded
self
.
assertEqual
(
x_sine_part
.
size
(
0
),
num_frames
)
self
.
assertEqual
(
l1_error
,
0.
)
# create a two channel version of this wavefile
x_2ch_sine
=
x_sine_full
.
repeat
(
1
,
2
)
out_2ch_sine_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'2ch_sinewave.wav'
)
torchaudio
.
save
(
out_2ch_sine_path
,
x_2ch_sine
,
sr_sine
)
x_2ch_sine_load
,
_
=
torchaudio
.
load
(
out_2ch_sine_path
,
num_frames
=
num_frames
,
offset
=
offset
)
os
.
unlink
(
out_2ch_sine_path
)
l1_error
=
x_2ch_sine_load
.
sub
(
x_2ch_sine
[
offset
:(
offset
+
num_frames
)]).
abs
().
sum
().
item
()
self
.
assertEqual
(
l1_error
,
0.
)
# test with two channel mp3
x_2ch_full
,
sr_2ch
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
True
)
x_2ch_part
,
_
=
torchaudio
.
load
(
self
.
test_filepath
,
normalization
=
True
,
num_frames
=
num_frames
,
offset
=
offset
)
l1_error
=
x_2ch_full
[
offset
:(
offset
+
num_frames
)].
sub
(
x_2ch_part
).
abs
().
sum
().
item
()
self
.
assertEqual
(
x_2ch_part
.
size
(
0
),
num_frames
)
self
.
assertEqual
(
l1_error
,
0.
)
# check behavior if number of samples would exceed file length
offset_ns
=
300
x_ns
,
_
=
torchaudio
.
load
(
input_sine_path
,
num_frames
=
100000
,
offset
=
offset_ns
)
self
.
assertEqual
(
x_ns
.
size
(
0
),
x_sine_full
.
size
(
0
)
-
offset_ns
)
# check when offset is beyond the end of the file
with
self
.
assertRaises
(
RuntimeError
):
torchaudio
.
load
(
input_sine_path
,
offset
=
100000
)
def
test_get_info
(
self
):
input_path
=
os
.
path
.
join
(
self
.
test_dirpath
,
'assets'
,
'sinewave.wav'
)
info_expected
=
(
1
,
64000
,
16000
,
32
)
info_load
=
torchaudio
.
info
(
input_path
)
self
.
assertEqual
(
info_load
,
info_expected
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
torchaudio/__init__.py
View file @
7314b36d
...
@@ -18,7 +18,7 @@ def check_input(src):
...
@@ -18,7 +18,7 @@ def check_input(src):
raise
TypeError
(
'Expected a CPU based tensor, got %s'
%
type
(
src
))
raise
TypeError
(
'Expected a CPU based tensor, got %s'
%
type
(
src
))
def
load
(
filepath
,
out
=
None
,
normalization
=
None
):
def
load
(
filepath
,
out
=
None
,
normalization
=
None
,
num_frames
=-
1
,
offset
=
0
):
"""Loads an audio file from disk into a Tensor
"""Loads an audio file from disk into a Tensor
Args:
Args:
...
@@ -27,6 +27,8 @@ def load(filepath, out=None, normalization=None):
...
@@ -27,6 +27,8 @@ def load(filepath, out=None, normalization=None):
normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31`
normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes 16-bit depth audio, and normalizes to `[0, 1]`.
(assumes 16-bit depth audio, and normalizes to `[0, 1]`.
If `number`, then output is divided by that number
If `number`, then output is divided by that number
num_frames (int, optional): number of frames to load. -1 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading.
Returns: tuple(Tensor, int)
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels
- Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels
...
@@ -51,7 +53,12 @@ def load(filepath, out=None, normalization=None):
...
@@ -51,7 +53,12 @@ def load(filepath, out=None, normalization=None):
else
:
else
:
out
=
torch
.
FloatTensor
()
out
=
torch
.
FloatTensor
()
sample_rate
=
_torch_sox
.
read_audio_file
(
filepath
,
out
)
if
num_frames
<
-
1
:
raise
ValueError
(
"Expected value for num_samples -1 (entire file) or >=0"
)
if
offset
<
0
:
raise
ValueError
(
"Expected positive offset value"
)
sample_rate
=
_torch_sox
.
read_audio_file
(
filepath
,
out
,
num_frames
,
offset
)
# normalize if needed
# normalize if needed
if
isinstance
(
normalization
,
bool
)
and
normalization
:
if
isinstance
(
normalization
,
bool
)
and
normalization
:
out
/=
1
<<
31
# assuming 16-bit depth
out
/=
1
<<
31
# assuming 16-bit depth
...
@@ -61,7 +68,7 @@ def load(filepath, out=None, normalization=None):
...
@@ -61,7 +68,7 @@ def load(filepath, out=None, normalization=None):
return
out
,
sample_rate
return
out
,
sample_rate
def
save
(
filepath
,
src
,
sample_rate
):
def
save
(
filepath
,
src
,
sample_rate
,
precision
=
32
):
"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc.
"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc.
Args:
Args:
...
@@ -69,6 +76,7 @@ def save(filepath, src, sample_rate):
...
@@ -69,6 +76,7 @@ def save(filepath, src, sample_rate):
src (Tensor): an input 2D Tensor of shape `[L x C]` where L is
src (Tensor): an input 2D Tensor of shape `[L x C]` where L is
the number of audio frames, C is the number of channels
the number of audio frames, C is the number of channels
sample_rate (int): the sample-rate of the audio to be saved
sample_rate (int): the sample-rate of the audio to be saved
precision (int, optional): the bit-precision of the audio to be saved
Example::
Example::
...
@@ -93,6 +101,12 @@ def save(filepath, src, sample_rate):
...
@@ -93,6 +101,12 @@ def save(filepath, src, sample_rate):
sample_rate
=
int
(
sample_rate
)
sample_rate
=
int
(
sample_rate
)
else
:
else
:
raise
TypeError
(
'Sample rate should be a integer'
)
raise
TypeError
(
'Sample rate should be a integer'
)
# check if bit_rate is an integer
if
not
isinstance
(
precision
,
int
):
if
int
(
precision
)
==
precision
:
precision
=
int
(
precision
)
else
:
raise
TypeError
(
'Bit precision should be a integer'
)
# programs such as librosa normalize the signal, unnormalize if detected
# programs such as librosa normalize the signal, unnormalize if detected
if
src
.
min
()
>=
-
1.0
and
src
.
max
()
<=
1.0
:
if
src
.
min
()
>=
-
1.0
and
src
.
max
()
<=
1.0
:
src
=
src
*
(
1
<<
31
)
# assuming 16-bit depth
src
=
src
*
(
1
<<
31
)
# assuming 16-bit depth
...
@@ -100,4 +114,23 @@ def save(filepath, src, sample_rate):
...
@@ -100,4 +114,23 @@ def save(filepath, src, sample_rate):
# save data to file
# save data to file
extension
=
os
.
path
.
splitext
(
filepath
)[
1
]
extension
=
os
.
path
.
splitext
(
filepath
)[
1
]
check_input
(
src
)
check_input
(
src
)
_torch_sox
.
write_audio_file
(
filepath
,
src
,
extension
[
1
:],
sample_rate
)
_torch_sox
.
write_audio_file
(
filepath
,
src
,
extension
[
1
:],
sample_rate
,
precision
)
def
info
(
filepath
):
"""Gets metadata from an audio file without loading the signal.
Args:
filepath (string): path to audio file
Returns: tuple(C, L, sr, precision)
- C (int): number of audio channels
- L (int): length of each channel in frames (samples / channels)
- sr (int): sample rate i.e. samples per second
- precision (float): bit precision i.e. 32-bit or 16-bit audio
Example::
>>> num_channels, length, sample_rate, precision = torchaudio.info('foo.wav')
"""
C
,
L
,
sr
,
bp
=
_torch_sox
.
get_info
(
filepath
)
return
C
,
L
,
sr
,
bp
torchaudio/torch_sox.cpp
View file @
7314b36d
...
@@ -35,8 +35,12 @@ void read_audio(
...
@@ -35,8 +35,12 @@ void read_audio(
SoxDescriptor
&
fd
,
SoxDescriptor
&
fd
,
at
::
Tensor
output
,
at
::
Tensor
output
,
int64_t
number_of_channels
,
int64_t
number_of_channels
,
int64_t
buffer_length
)
{
int64_t
buffer_length
,
int64_t
offset
)
{
std
::
vector
<
sox_sample_t
>
buffer
(
buffer_length
);
std
::
vector
<
sox_sample_t
>
buffer
(
buffer_length
);
if
(
sox_seek
(
fd
.
get
(),
offset
,
0
)
==
SOX_EOF
)
{
throw
std
::
runtime_error
(
"sox_seek reached EOF, try reducing offset or num_samples"
);
}
const
int64_t
samples_read
=
sox_read
(
fd
.
get
(),
buffer
.
data
(),
buffer_length
);
const
int64_t
samples_read
=
sox_read
(
fd
.
get
(),
buffer
.
data
(),
buffer_length
);
if
(
samples_read
==
0
)
{
if
(
samples_read
==
0
)
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
...
@@ -67,7 +71,11 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
...
@@ -67,7 +71,11 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
}
}
}
// namespace
}
// namespace
int
read_audio_file
(
const
std
::
string
&
file_name
,
at
::
Tensor
output
)
{
int
read_audio_file
(
const
std
::
string
&
file_name
,
at
::
Tensor
output
,
int64_t
nframes
,
int64_t
offset
)
{
SoxDescriptor
fd
(
sox_open_read
(
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
file_name
.
c_str
(),
/*signal=*/
nullptr
,
/*signal=*/
nullptr
,
...
@@ -79,12 +87,26 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
...
@@ -79,12 +87,26 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
const
int64_t
number_of_channels
=
fd
->
signal
.
channels
;
const
int64_t
number_of_channels
=
fd
->
signal
.
channels
;
const
int
sample_rate
=
fd
->
signal
.
rate
;
const
int
sample_rate
=
fd
->
signal
.
rate
;
const
int64_t
buffer
_length
=
fd
->
signal
.
length
;
const
int64_t
total
_length
=
fd
->
signal
.
length
;
if
(
buffer
_length
==
0
)
{
if
(
total
_length
==
0
)
{
throw
std
::
runtime_error
(
"Error reading audio file: unknown length"
);
throw
std
::
runtime_error
(
"Error reading audio file: unknown length"
);
}
}
read_audio
(
fd
,
output
,
number_of_channels
,
buffer_length
);
// calculate buffer length
int64_t
buffer_length
=
total_length
;
if
(
offset
>
0
&&
offset
<
total_length
)
{
buffer_length
-=
offset
;
}
if
(
nframes
!=
-
1
&&
buffer_length
>
nframes
)
{
// get requested number of frames
buffer_length
=
nframes
;
}
// buffer length and offset need to be multipled by the number of channels
buffer_length
*=
number_of_channels
;
offset
*=
number_of_channels
;
read_audio
(
fd
,
output
,
number_of_channels
,
buffer_length
,
offset
);
return
sample_rate
;
return
sample_rate
;
}
}
...
@@ -93,7 +115,8 @@ void write_audio_file(
...
@@ -93,7 +115,8 @@ void write_audio_file(
const
std
::
string
&
file_name
,
const
std
::
string
&
file_name
,
at
::
Tensor
tensor
,
at
::
Tensor
tensor
,
const
std
::
string
&
extension
,
const
std
::
string
&
extension
,
int
sample_rate
)
{
int
sample_rate
,
int
precision
)
{
if
(
!
tensor
.
is_contiguous
())
{
if
(
!
tensor
.
is_contiguous
())
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"Error writing audio file: input tensor must be contiguous"
);
"Error writing audio file: input tensor must be contiguous"
);
...
@@ -103,7 +126,7 @@ void write_audio_file(
...
@@ -103,7 +126,7 @@ void write_audio_file(
signal
.
rate
=
sample_rate
;
signal
.
rate
=
sample_rate
;
signal
.
channels
=
tensor
.
size
(
1
);
signal
.
channels
=
tensor
.
size
(
1
);
signal
.
length
=
tensor
.
numel
();
signal
.
length
=
tensor
.
numel
();
signal
.
precision
=
32
;
// precision in bits
signal
.
precision
=
precision
;
// precision in bits
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
signal
.
mult
=
nullptr
;
signal
.
mult
=
nullptr
;
...
@@ -129,6 +152,24 @@ void write_audio_file(
...
@@ -129,6 +152,24 @@ void write_audio_file(
"Error writing audio file: could not write entire buffer"
);
"Error writing audio file: could not write entire buffer"
);
}
}
}
}
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info
(
const
std
::
string
&
file_name
)
{
SoxDescriptor
fd
(
sox_open_read
(
file_name
.
c_str
(),
/*signal=*/
nullptr
,
/*encoding=*/
nullptr
,
/*filetype=*/
nullptr
));
if
(
fd
.
get
()
==
nullptr
)
{
throw
std
::
runtime_error
(
"Error opening audio file"
);
}
int64_t
nchannels
=
fd
->
signal
.
channels
;
int64_t
length
=
fd
->
signal
.
length
;
int64_t
sample_rate
=
fd
->
signal
.
rate
;
int64_t
precision
=
fd
->
signal
.
precision
;
return
std
::
make_tuple
(
nchannels
,
length
,
sample_rate
,
precision
);
}
}
// namespace audio
}
// namespace audio
}
// namespace torch
}
// namespace torch
...
@@ -141,4 +182,8 @@ PYBIND11_MODULE(_torch_sox, m) {
...
@@ -141,4 +182,8 @@ PYBIND11_MODULE(_torch_sox, m) {
"write_audio_file"
,
"write_audio_file"
,
&
torch
::
audio
::
write_audio_file
,
&
torch
::
audio
::
write_audio_file
,
"Writes data from a tensor into an audio file"
);
"Writes data from a tensor into an audio file"
);
m
.
def
(
"get_info"
,
&
torch
::
audio
::
get_info
,
"Gets information about an audio file"
);
}
}
torchaudio/torch_sox.h
View file @
7314b36d
...
@@ -10,7 +10,11 @@ namespace torch { namespace audio {
...
@@ -10,7 +10,11 @@ namespace torch { namespace audio {
/// returns the sample rate of the audio file.
/// returns the sample rate of the audio file.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error ocurred during reading of the audio data.
/// error ocurred during reading of the audio data.
int
read_audio_file
(
const
std
::
string
&
path
,
at
::
Tensor
output
);
int
read_audio_file
(
const
std
::
string
&
path
,
at
::
Tensor
output
,
int64_t
number_of_samples
,
int64_t
offset
);
/// Writes the data of a `Tensor` into an audio file at the given `path`, with
/// Writes the data of a `Tensor` into an audio file at the given `path`, with
/// a certain extension (e.g. `wav`or `mp3`) and sample rate.
/// a certain extension (e.g. `wav`or `mp3`) and sample rate.
...
@@ -20,5 +24,13 @@ void write_audio_file(
...
@@ -20,5 +24,13 @@ void write_audio_file(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
at
::
Tensor
tensor
,
at
::
Tensor
tensor
,
const
std
::
string
&
extension
,
const
std
::
string
&
extension
,
int
sample_rate
);
int
sample_rate
,
int
precision
);
/// Reads an audio file from the given `path` and returns a tuple of
/// the number of channels, length in samples, sample rate, and bits / sec.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error ocurred during reading of the audio data.
std
::
tuple
<
int64_t
,
int64_t
,
int64_t
,
int64_t
>
get_info
(
const
std
::
string
&
file_name
);
}}
// namespace torch::audio
}}
// namespace torch::audio
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment