Commit c340a8d1 authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Soumith Chintala
Browse files

Conform to sox library a bit better

parent cba11009
......@@ -9,6 +9,9 @@
namespace torch { namespace audio {
int read_audio_file(const std::string& file_name, at::Tensor output) {
if (sox_init() != SOX_SUCCESS) {
throw std::runtime_error("Error initializing sox library");
}
sox_format_t* fd = sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
......@@ -25,7 +28,7 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
throw std::runtime_error("Error reading audio file: unknown length");
}
std::vector<int32_t> buffer(buffer_length);
std::vector<sox_sample_t> buffer(buffer_length);
const int64_t samples_read = sox_read(fd, buffer.data(), buffer_length);
if (samples_read == 0) {
throw std::runtime_error(
......@@ -40,6 +43,8 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
std::copy(buffer.begin(), buffer.begin() + samples_read, data);
});
sox_quit();
return sample_rate;
}
......@@ -48,12 +53,15 @@ void write_audio_file(
at::Tensor tensor,
const std::string& extension,
int sample_rate) {
if (sox_init() != SOX_SUCCESS) {
throw std::runtime_error("Error initializing sox library");
}
if (!tensor.is_contiguous()) {
throw std::runtime_error(
"Error writing audio file: input tensor must be contiguous");
}
// Create sox objects and write into int32_t buffer.
sox_signalinfo_t signal;
signal.rate = sample_rate;
signal.channels = tensor.size(1);
......@@ -77,7 +85,7 @@ void write_audio_file(
"Error writing audio file: could not open file for writing");
}
std::vector<int32_t> buffer(tensor.numel());
std::vector<sox_sample_t> buffer(tensor.numel());
AT_DISPATCH_ALL_TYPES(tensor.type(), "write_audio_buffer", [&] {
auto* data = tensor.data<scalar_t>();
......@@ -86,8 +94,9 @@ void write_audio_file(
const auto samples_written = sox_write(fd, buffer.data(), buffer.size());
// Free buffer and sox structures.
// Free buffer and sox structures (before possibly throwing).
sox_close(fd);
sox_quit();
if (samples_written != buffer.size()) {
throw std::runtime_error(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment