Commit 7314b36d authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

allow loading with offsets and number of samples and saving specified bit precisions (#59)

parent 0b93ff06
......@@ -99,6 +99,15 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.save(sinewave_filepath, y, sr)
self.assertTrue(os.path.isfile(sinewave_filepath))
# test precision
new_filepath = os.path.join(self.test_dirpath, "test.wav")
_, _, _, bp = torchaudio.info(sinewave_filepath)
torchaudio.save(new_filepath, y, sr, precision=16)
_, _, _, bp16 = torchaudio.info(new_filepath)
self.assertEqual(bp, 32)
self.assertEqual(bp16, 16)
os.unlink(new_filepath)
def test_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
......@@ -109,6 +118,48 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path)
def test_load_partial(self):
    """Partial loads (``num_frames`` / ``offset``) must match the same window of a full load."""
    offset = 200
    num_frames = 100
    # the frame window every partial load below is expected to return
    window = slice(offset, offset + num_frames)

    def summed_abs_diff(lhs, rhs):
        # L1 distance between two tensors, as a plain Python number
        return lhs.sub(rhs).abs().sum().item()

    # mono wav: load everything, load a partial copy, then compare the two
    mono_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    mono_full, mono_sr = torchaudio.load(mono_path)
    mono_part, _ = torchaudio.load(mono_path, num_frames=num_frames, offset=offset)
    mono_err = summed_abs_diff(mono_full[window], mono_part)
    # right number of frames, and the right portion of the file
    self.assertEqual(mono_part.size(0), num_frames)
    self.assertEqual(mono_err, 0.)

    # two channel wav: duplicate the mono signal, write it out, partially reload it
    stereo_src = mono_full.repeat(1, 2)
    stereo_path = os.path.join(self.test_dirpath, 'assets', '2ch_sinewave.wav')
    torchaudio.save(stereo_path, stereo_src, mono_sr)
    stereo_part, _ = torchaudio.load(stereo_path, num_frames=num_frames, offset=offset)
    # the temporary file is removed before asserting so a failure does not leave it behind
    os.unlink(stereo_path)
    stereo_err = summed_abs_diff(stereo_part, stereo_src[window])
    self.assertEqual(stereo_err, 0.)

    # two channel mp3: same comparison on the normalized signal
    mp3_full, _ = torchaudio.load(self.test_filepath, normalization=True)
    mp3_part, _ = torchaudio.load(
        self.test_filepath, normalization=True, num_frames=num_frames, offset=offset)
    mp3_err = summed_abs_diff(mp3_full[window], mp3_part)
    self.assertEqual(mp3_part.size(0), num_frames)
    self.assertEqual(mp3_err, 0.)

    # requesting more frames than remain returns everything after the offset
    tail_offset = 300
    tail, _ = torchaudio.load(mono_path, num_frames=100000, offset=tail_offset)
    self.assertEqual(tail.size(0), mono_full.size(0) - tail_offset)

    # an offset beyond the end of the file is an error
    with self.assertRaises(RuntimeError):
        torchaudio.load(mono_path, offset=100000)
def test_get_info(self):
    """``torchaudio.info`` reports the expected metadata for the bundled sine wave asset."""
    sinewave_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    # tuple layout: (channels, length, sample rate, bit precision)
    expected = (1, 64000, 16000, 32)
    self.assertEqual(torchaudio.info(sinewave_path), expected)
if __name__ == '__main__':
unittest.main()
......@@ -18,7 +18,7 @@ def check_input(src):
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
def load(filepath, out=None, normalization=None):
def load(filepath, out=None, normalization=None, num_frames=-1, offset=0):
"""Loads an audio file from disk into a Tensor
Args:
......@@ -27,6 +27,8 @@ def load(filepath, out=None, normalization=None):
normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes 16-bit depth audio, and normalizes to `[0, 1]`).
If `number`, then output is divided by that number
num_frames (int, optional): number of frames to load. -1 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading.
Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels
......@@ -51,7 +53,12 @@ def load(filepath, out=None, normalization=None):
else:
out = torch.FloatTensor()
sample_rate = _torch_sox.read_audio_file(filepath, out)
if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if offset < 0:
raise ValueError("Expected positive offset value")
sample_rate = _torch_sox.read_audio_file(filepath, out, num_frames, offset)
# normalize if needed
if isinstance(normalization, bool) and normalization:
out /= 1 << 31 # assuming 16-bit depth
......@@ -61,7 +68,7 @@ def load(filepath, out=None, normalization=None):
return out, sample_rate
def save(filepath, src, sample_rate):
def save(filepath, src, sample_rate, precision=32):
"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc.
Args:
......@@ -69,6 +76,7 @@ def save(filepath, src, sample_rate):
src (Tensor): an input 2D Tensor of shape `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate (int): the sample-rate of the audio to be saved
precision (int, optional): the bit-precision of the audio to be saved
Example::
......@@ -93,6 +101,12 @@ def save(filepath, src, sample_rate):
sample_rate = int(sample_rate)
else:
raise TypeError('Sample rate should be a integer')
# check if bit_rate is an integer
if not isinstance(precision, int):
if int(precision) == precision:
precision = int(precision)
else:
raise TypeError('Bit precision should be a integer')
# programs such as librosa normalize the signal, unnormalize if detected
if src.min() >= -1.0 and src.max() <= 1.0:
src = src * (1 << 31) # assuming 16-bit depth
......@@ -100,4 +114,23 @@ def save(filepath, src, sample_rate):
# save data to file
extension = os.path.splitext(filepath)[1]
check_input(src)
_torch_sox.write_audio_file(filepath, src, extension[1:], sample_rate)
_torch_sox.write_audio_file(filepath, src, extension[1:], sample_rate, precision)
def info(filepath):
    """Gets metadata from an audio file without loading the signal.

    Args:
        filepath (string): path to audio file

    Returns: tuple(C, L, sr, precision)
        - C (int): number of audio channels
        - L (int): length of the audio as reported by the SoX backend; the
          value is returned unmodified.
          NOTE(review): SoX documents this field as samples * channels, so
          for multi-channel files it may not be the per-channel frame
          count -- confirm before relying on it as "frames".
        - sr (int): sample rate i.e. samples per second
        - precision (int): bit precision i.e. 32-bit or 16-bit audio

    Example::
        >>> num_channels, length, sample_rate, precision = torchaudio.info('foo.wav')
    """
    # The extension returns a 4-tuple of integers; unpacking keeps the
    # documented order explicit and fails loudly if the arity ever changes.
    num_channels, length, sample_rate, precision = _torch_sox.get_info(filepath)
    return num_channels, length, sample_rate, precision
......@@ -35,8 +35,12 @@ void read_audio(
SoxDescriptor& fd,
at::Tensor output,
int64_t number_of_channels,
int64_t buffer_length) {
int64_t buffer_length,
int64_t offset) {
std::vector<sox_sample_t> buffer(buffer_length);
if (sox_seek(fd.get(), offset, 0) == SOX_EOF) {
throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples");
}
const int64_t samples_read = sox_read(fd.get(), buffer.data(), buffer_length);
if (samples_read == 0) {
throw std::runtime_error(
......@@ -67,7 +71,11 @@ int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
}
} // namespace
int read_audio_file(const std::string& file_name, at::Tensor output) {
int read_audio_file(
const std::string& file_name,
at::Tensor output,
int64_t nframes,
int64_t offset) {
SoxDescriptor fd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
......@@ -79,12 +87,26 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
const int64_t number_of_channels = fd->signal.channels;
const int sample_rate = fd->signal.rate;
const int64_t buffer_length = fd->signal.length;
if (buffer_length == 0) {
const int64_t total_length = fd->signal.length;
if (total_length == 0) {
throw std::runtime_error("Error reading audio file: unknown length");
}
read_audio(fd, output, number_of_channels, buffer_length);
// calculate buffer length
int64_t buffer_length = total_length;
if (offset > 0 && offset < total_length) {
buffer_length -= offset;
}
if (nframes != -1 && buffer_length > nframes) {
// get requested number of frames
buffer_length = nframes;
}
// buffer length and offset need to be multiplied by the number of channels
buffer_length *= number_of_channels;
offset *= number_of_channels;
read_audio(fd, output, number_of_channels, buffer_length, offset);
return sample_rate;
}
......@@ -93,7 +115,8 @@ void write_audio_file(
const std::string& file_name,
at::Tensor tensor,
const std::string& extension,
int sample_rate) {
int sample_rate,
int precision) {
if (!tensor.is_contiguous()) {
throw std::runtime_error(
"Error writing audio file: input tensor must be contiguous");
......@@ -103,7 +126,7 @@ void write_audio_file(
signal.rate = sample_rate;
signal.channels = tensor.size(1);
signal.length = tensor.numel();
signal.precision = 32; // precision in bits
signal.precision = precision; // precision in bits
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
signal.mult = nullptr;
......@@ -129,6 +152,24 @@ void write_audio_file(
"Error writing audio file: could not write entire buffer");
}
}
/// Opens the audio file `file_name` and returns its metadata as a tuple of
/// (number of channels, length, sample rate, precision in bits), each taken
/// from the opened descriptor's signal info and converted to `int64_t`.
/// Throws `std::runtime_error` if the file could not be opened.
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info(
    const std::string& file_name) {
  // No signal / encoding / filetype overrides are passed to SoX.
  SoxDescriptor fd(sox_open_read(
      file_name.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/nullptr));
  if (fd.get() == nullptr) {
    throw std::runtime_error("Error opening audio file");
  }

  const auto& signal = fd->signal;
  return std::make_tuple(
      static_cast<int64_t>(signal.channels),
      static_cast<int64_t>(signal.length),
      static_cast<int64_t>(signal.rate),
      static_cast<int64_t>(signal.precision));
}
} // namespace audio
} // namespace torch
......@@ -141,4 +182,8 @@ PYBIND11_MODULE(_torch_sox, m) {
"write_audio_file",
&torch::audio::write_audio_file,
"Writes data from a tensor into an audio file");
m.def(
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
}
......@@ -10,7 +10,11 @@ namespace torch { namespace audio {
/// returns the sample rate of the audio file.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error occurred during reading of the audio data.
int read_audio_file(const std::string& path, at::Tensor output);
int read_audio_file(
const std::string& path,
at::Tensor output,
int64_t number_of_samples,
int64_t offset);
/// Writes the data of a `Tensor` into an audio file at the given `path`, with
/// a certain extension (e.g. `wav`or `mp3`) and sample rate.
......@@ -20,5 +24,13 @@ void write_audio_file(
const std::string& path,
at::Tensor tensor,
const std::string& extension,
int sample_rate);
int sample_rate,
int precision);
/// Reads an audio file from the given `path` and returns a tuple of
/// the number of channels, length in samples, sample rate, and precision in bits.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error occurred during reading of the audio data.
std::tuple<int64_t, int64_t, int64_t, int64_t> get_info(
const std::string& file_name);
}} // namespace torch::audio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment