Commit c340a8d1 authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Soumith Chintala
Browse files

Conform to sox library a bit better

parent cba11009
...@@ -9,6 +9,9 @@ ...@@ -9,6 +9,9 @@
namespace torch { namespace audio { namespace torch { namespace audio {
int read_audio_file(const std::string& file_name, at::Tensor output) { int read_audio_file(const std::string& file_name, at::Tensor output) {
if (sox_init() != SOX_SUCCESS) {
throw std::runtime_error("Error initializing sox library");
}
sox_format_t* fd = sox_open_read( sox_format_t* fd = sox_open_read(
file_name.c_str(), file_name.c_str(),
/*signal=*/nullptr, /*signal=*/nullptr,
...@@ -25,7 +28,7 @@ int read_audio_file(const std::string& file_name, at::Tensor output) { ...@@ -25,7 +28,7 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
throw std::runtime_error("Error reading audio file: unknown length"); throw std::runtime_error("Error reading audio file: unknown length");
} }
std::vector<int32_t> buffer(buffer_length); std::vector<sox_sample_t> buffer(buffer_length);
const int64_t samples_read = sox_read(fd, buffer.data(), buffer_length); const int64_t samples_read = sox_read(fd, buffer.data(), buffer_length);
if (samples_read == 0) { if (samples_read == 0) {
throw std::runtime_error( throw std::runtime_error(
...@@ -40,6 +43,8 @@ int read_audio_file(const std::string& file_name, at::Tensor output) { ...@@ -40,6 +43,8 @@ int read_audio_file(const std::string& file_name, at::Tensor output) {
std::copy(buffer.begin(), buffer.begin() + samples_read, data); std::copy(buffer.begin(), buffer.begin() + samples_read, data);
}); });
sox_quit();
return sample_rate; return sample_rate;
} }
...@@ -48,12 +53,15 @@ void write_audio_file( ...@@ -48,12 +53,15 @@ void write_audio_file(
at::Tensor tensor, at::Tensor tensor,
const std::string& extension, const std::string& extension,
int sample_rate) { int sample_rate) {
if (sox_init() != SOX_SUCCESS) {
throw std::runtime_error("Error initializing sox library");
}
if (!tensor.is_contiguous()) { if (!tensor.is_contiguous()) {
throw std::runtime_error( throw std::runtime_error(
"Error writing audio file: input tensor must be contiguous"); "Error writing audio file: input tensor must be contiguous");
} }
// Create sox objects and write into int32_t buffer.
sox_signalinfo_t signal; sox_signalinfo_t signal;
signal.rate = sample_rate; signal.rate = sample_rate;
signal.channels = tensor.size(1); signal.channels = tensor.size(1);
...@@ -77,7 +85,7 @@ void write_audio_file( ...@@ -77,7 +85,7 @@ void write_audio_file(
"Error writing audio file: could not open file for writing"); "Error writing audio file: could not open file for writing");
} }
std::vector<int32_t> buffer(tensor.numel()); std::vector<sox_sample_t> buffer(tensor.numel());
AT_DISPATCH_ALL_TYPES(tensor.type(), "write_audio_buffer", [&] { AT_DISPATCH_ALL_TYPES(tensor.type(), "write_audio_buffer", [&] {
auto* data = tensor.data<scalar_t>(); auto* data = tensor.data<scalar_t>();
...@@ -86,8 +94,9 @@ void write_audio_file( ...@@ -86,8 +94,9 @@ void write_audio_file(
const auto samples_written = sox_write(fd, buffer.data(), buffer.size()); const auto samples_written = sox_write(fd, buffer.data(), buffer.size());
// Free buffer and sox structures. // Free buffer and sox structures (before possibly throwing).
sox_close(fd); sox_close(fd);
sox_quit();
if (samples_written != buffer.size()) { if (samples_written != buffer.size()) {
throw std::runtime_error( throw std::runtime_error(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment