You need to sign in or sign up before continuing.
Commit 7314b36d authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

allow loading with offsets and number of samples and saving specified bit precisions (#59)

parent 0b93ff06
......@@ -99,6 +99,15 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.save(sinewave_filepath, y, sr)
self.assertTrue(os.path.isfile(sinewave_filepath))
# test precision
new_filepath = os.path.join(self.test_dirpath, "test.wav")
_, _, _, bp = torchaudio.info(sinewave_filepath)
torchaudio.save(new_filepath, y, sr, precision=16)
_, _, _, bp16 = torchaudio.info(new_filepath)
self.assertEqual(bp, 32)
self.assertEqual(bp16, 16)
os.unlink(new_filepath)
def test_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
......@@ -109,6 +118,48 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path)
def test_load_partial(self):
    """Check that `torchaudio.load` honours the `num_frames` and `offset` arguments.

    Covers a mono wav file, a two channel wav file generated on the fly,
    a two channel mp3 file, a request that runs past the end of the file,
    and an offset that lies beyond the end of the file.
    """
    num_frames = 100
    offset = 200

    # load entire mono sinewave wav file, load a partial copy and then compare
    input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    x_sine_full, sr_sine = torchaudio.load(input_sine_path)
    x_sine_part, _ = torchaudio.load(input_sine_path, num_frames=num_frames, offset=offset)
    l1_error = x_sine_full[offset:(num_frames + offset)].sub(x_sine_part).abs().sum().item()
    # test for the correct number of samples and that the correct portion was loaded
    self.assertEqual(x_sine_part.size(0), num_frames)
    self.assertEqual(l1_error, 0.)

    # create a two channel version of this wavefile
    x_2ch_sine = x_sine_full.repeat(1, 2)
    out_2ch_sine_path = os.path.join(self.test_dirpath, 'assets', '2ch_sinewave.wav')
    try:
        torchaudio.save(out_2ch_sine_path, x_2ch_sine, sr_sine)
        x_2ch_sine_load, _ = torchaudio.load(out_2ch_sine_path, num_frames=num_frames, offset=offset)
    finally:
        # always remove the temporary file so a failing save/load does not
        # leave a stray wav file behind in the assets directory
        if os.path.isfile(out_2ch_sine_path):
            os.unlink(out_2ch_sine_path)
    l1_error = x_2ch_sine_load.sub(x_2ch_sine[offset:(offset + num_frames)]).abs().sum().item()
    # same two checks as for the mono and mp3 cases: frame count and content
    self.assertEqual(x_2ch_sine_load.size(0), num_frames)
    self.assertEqual(l1_error, 0.)

    # test with two channel mp3
    x_2ch_full, sr_2ch = torchaudio.load(self.test_filepath, normalization=True)
    x_2ch_part, _ = torchaudio.load(self.test_filepath, normalization=True, num_frames=num_frames, offset=offset)
    l1_error = x_2ch_full[offset:(offset + num_frames)].sub(x_2ch_part).abs().sum().item()
    self.assertEqual(x_2ch_part.size(0), num_frames)
    self.assertEqual(l1_error, 0.)

    # check behavior if number of samples would exceed file length
    offset_ns = 300
    x_ns, _ = torchaudio.load(input_sine_path, num_frames=100000, offset=offset_ns)
    self.assertEqual(x_ns.size(0), x_sine_full.size(0) - offset_ns)

    # check when offset is beyond the end of the file
    with self.assertRaises(RuntimeError):
        torchaudio.load(input_sine_path, offset=100000)
def test_get_info(self):
    """`torchaudio.info` must report the known metadata of the sinewave asset."""
    sinewave_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    # expected tuple layout: (channels, length, sample rate, bit precision)
    self.assertEqual(torchaudio.info(sinewave_path), (1, 64000, 16000, 32))
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
......@@ -18,7 +18,7 @@ def check_input(src):
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
def load(filepath, out=None, normalization=None):
def load(filepath, out=None, normalization=None, num_frames=-1, offset=0):
"""Loads an audio file from disk into a Tensor
Args:
......@@ -27,6 +27,8 @@ def load(filepath, out=None, normalization=None):
normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `number`, then output is divided by that number
num_frames (int, optional): number of frames to load. -1 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading.
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels
......@@ -51,7 +53,12 @@ def load(filepath, out=None, normalization=None):
else:
out = torch.FloatTensor()
sample_rate = _torch_sox.read_audio_file(filepath, out)
if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if offset < 0:
raise ValueError("Expected positive offset value")
sample_rate = _torch_sox.read_audio_file(filepath, out, num_frames, offset)
# normalize if needed
if isinstance(normalization, bool) and normalization:
out /= 1 << 31 # assuming 16-bit depth
......@@ -61,7 +68,7 @@ def load(filepath, out=None, normalization=None):
return out, sample_rate
def save(filepath, src, sample_rate):
def save(filepath, src, sample_rate, precision=32):
"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc.
Args:
......@@ -69,6 +76,7 @@ def save(filepath, src, sample_rate):
src (Tensor): an input 2D Tensor of shape `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate (int): the sample-rate of the audio to be saved
precision (int, optional): the bit-precision of the audio to be saved
Example::
......@@ -93,6 +101,12 @@ def save(filepath, src, sample_rate):
sample_rate = int(sample_rate)
else:
raise TypeError('Sample rate should be a integer')
# check if precision is an integer
if not isinstance(precision, int):
if int(precision) == precision:
precision = int(precision)
else:
raise TypeError('Bit precision should be a integer')
# programs such as librosa normalize the signal, unnormalize if detected
if src.min() >= -1.0 and src.max() <= 1.0:
src = src * (1 << 31) # assuming 16-bit depth
......@@ -100,4 +114,23 @@ def save(filepath, src, sample_rate):
# save data to file
extension = os.path.splitext(filepath)[1]
check_input(src)
_torch_sox.write_audio_file(filepath, src, extension[1:], sample_rate)
_torch_sox.write_audio_file(filepath, src, extension[1:], sample_rate, precision)
def info(filepath):
    """Gets metadata from an audio file without loading the signal.

    Args:
        filepath (string): path to audio file

    Returns: tuple(C, L, sr, precision)
        - C (int): number of audio channels
        - L (int): length of the signal as reported by `_torch_sox.get_info`
          (NOTE(review): libsox counts samples across all channels, so for
          multi-channel files this is presumably `frames * C` - confirm)
        - sr (int): sample rate i.e. samples per second
        - precision (int): bit precision i.e. 32-bit or 16-bit audio

    Example::
        >>> num_channels, length, sample_rate, precision = torchaudio.info('foo.wav')
    """
    # unpacking doubles as a sanity check that exactly four fields came back
    C, L, sr, bp = _torch_sox.get_info(filepath)
    return C, L, sr, bp
......@@ -35,8 +35,12 @@ void read_audio(
SoxDescriptor& fd,
at::Tensor output,
int64_t number_of_channels,
int64_t buffer_length) {
int64_t buffer_length,
int64_t offset) {
std::vector<sox_sample_t> buffer(buffer_length);
if (sox_seek(fd.get(), offset, 0) == SOX_EOF) {
throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples");
}
const int64_t samples_read = sox_read(fd.get(), buffer.data(), buffer_length);
if (samples_read == 0) {
throw std::runtime_error(
......@@ -67,7 +71,11 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
}
} // namespace
int read_audio_file(const std::string& file_name, at::Tensor output) {
int read_audio_file(
const std::string& file_name,
at::Tensor output,
int64_t nframes,
int64_t offset) {
SoxDescriptor fd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
......@@ -79,12 +87,26 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
const int64_t number_of_channels = fd->signal.channels;
const int sample_rate = fd->signal.rate;
const int64_t buffer_length = fd->signal.length;
if (buffer_length == 0) {
const int64_t total_length = fd->signal.length;
if (total_length == 0) {
throw std::runtime_error("Error reading audio file: unknown length");
}
read_audio(fd, output, number_of_channels, buffer_length);
// calculate buffer length
int64_t buffer_length = total_length;
if (offset > 0 && offset < total_length) {
buffer_length -= offset;
}
if (nframes != -1 && buffer_length > nframes) {
// get requested number of frames
buffer_length = nframes;
}
// buffer length and offset need to be multiplied by the number of channels
buffer_length *= number_of_channels;
offset *= number_of_channels;
read_audio(fd, output, number_of_channels, buffer_length, offset);
return sample_rate;
}
......@@ -93,7 +115,8 @@ void write_audio_file(
const std::string& file_name,
at::Tensor tensor,
const std::string& extension,
int sample_rate) {
int sample_rate,
int precision) {
if (!tensor.is_contiguous()) {
throw std::runtime_error(
"Error writing audio file: input tensor must be contiguous");
......@@ -103,7 +126,7 @@ void write_audio_file(
signal.rate = sample_rate;
signal.channels = tensor.size(1);
signal.length = tensor.numel();
signal.precision = 32; // precision in bits
signal.precision = precision; // precision in bits
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
signal.mult = nullptr;
......@@ -129,6 +152,24 @@ void write_audio_file(
"Error writing audio file: could not write entire buffer");
}
}
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info(
const std::string& file_name
) {
SoxDescriptor fd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));
if (fd.get() == nullptr) {
throw std::runtime_error("Error opening audio file");
}
int64_t nchannels = fd->signal.channels;
int64_t length = fd->signal.length;
int64_t sample_rate = fd->signal.rate;
int64_t precision = fd->signal.precision;
return std::make_tuple(nchannels, length, sample_rate, precision);
}
} // namespace audio
} // namespace torch
......@@ -141,4 +182,8 @@ PYBIND11_MODULE(_torch_sox, m) {
"write_audio_file",
&torch::audio::write_audio_file,
"Writes data from a tensor into an audio file");
m.def(
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
}
......@@ -10,7 +10,11 @@ namespace torch { namespace audio {
/// returns the sample rate of the audio file.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error occurred during reading of the audio data.
int read_audio_file(const std::string& path, at::Tensor output);
int read_audio_file(
const std::string& path,
at::Tensor output,
int64_t number_of_samples,
int64_t offset);
/// Writes the data of a `Tensor` into an audio file at the given `path`, with
/// a certain extension (e.g. `wav`or `mp3`) and sample rate.
......@@ -20,5 +24,13 @@ void write_audio_file(
const std::string& path,
at::Tensor tensor,
const std::string& extension,
int sample_rate);
int sample_rate,
int precision);
/// Reads an audio file from the given `file_name` and returns a tuple of
/// the number of channels, length in samples, sample rate, and bit precision.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error occurred during reading of the audio data.
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info(
const std::string& file_name);
}} // namespace torch::audio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment