Unverified Commit 0406d30d authored by moto's avatar moto Committed by GitHub
Browse files

Replace sox_io save/load with sox effects chain in C++ (#779)

* Replace save/load function with sox effects chain
parent 0812f22a
......@@ -46,6 +46,9 @@ struct TensorInputPriv {
/// State holding a pointer to a vector of sox_sample_t.
/// NOTE(review): presumably the private data of the tensor output effect, with
/// `buffer` being the destination that tensor_output_flow fills; confirm
/// against that callback's body.
struct TensorOutputPriv {
std::vector<sox_sample_t>* buffer;
};
/// Private data of the "output_file" effect; its size is what
/// get_file_output_handler registers as `priv_size`.
struct FileOutputPriv {
// Output file handle: assigned in SoxEffectsChain::addOutputFile and passed
// to sox_write by file_output_flow.
sox_format_t* sf;
};
/// Callback function to feed Tensor data to SoxEffectsChain.
int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
......@@ -84,7 +87,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
/// Callback function to fetch data from SoxEffectsChain.
int tensor_output_flow(
sox_effect_t* effp LSX_UNUSED,
sox_effect_t* effp,
sox_sample_t const* ibuf,
sox_sample_t* obuf LSX_UNUSED,
size_t* isamp,
......@@ -97,6 +100,28 @@ int tensor_output_flow(
return SOX_SUCCESS;
}
/// Callback function to write the samples handed over by the effects chain
/// into the file stored in the effect's FileOutputPriv.
///
/// Returns SOX_SUCCESS when there is nothing to write or every sample was
/// written, and SOX_EOF on a short write for which the file reports no error.
/// Throws std::runtime_error on a short write when the file's sox_errno is set.
int file_output_flow(
    sox_effect_t* effp,
    sox_sample_t const* ibuf,
    sox_sample_t* obuf LSX_UNUSED,
    size_t* isamp,
    size_t* osamp) {
  // Nothing is ever placed in obuf, so report zero output samples.
  *osamp = 0;

  const size_t num_requested = *isamp;
  if (num_requested == 0) {
    return SOX_SUCCESS;
  }

  sox_format_t* const file = static_cast<FileOutputPriv*>(effp->priv)->sf;
  const size_t num_written = sox_write(file, ibuf, num_requested);
  if (num_written == num_requested) {
    return SOX_SUCCESS;
  }

  // Short write: without an error code it is reported as end-of-stream.
  if (!file->sox_errno) {
    return SOX_EOF;
  }

  std::ostringstream message;
  message << file->sox_errstr << " " << sox_strerror(file->sox_errno) << " "
          << file->filename;
  throw std::runtime_error(message.str());
}
sox_effect_handler_t* get_tensor_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_tensor",
/*usage=*/NULL,
......@@ -125,6 +150,20 @@ sox_effect_handler_t* get_tensor_output_handler() {
return &handler;
}
/// Returns a pointer to the function-local static handler describing the
/// "output_file" effect. Only the flow callback (file_output_flow) is set;
/// every other callback is left null.
sox_effect_handler_t* get_file_output_handler() {
  static sox_effect_handler_t file_output_handler{
      "output_file", // name
      nullptr, // usage
      SOX_EFF_MCHAN, // flags
      nullptr, // getopts
      nullptr, // start
      file_output_flow, // flow
      nullptr, // drain
      nullptr, // stop
      nullptr, // kill
      sizeof(FileOutputPriv)}; // priv_size
  return &file_output_handler;
}
} // namespace
SoxEffectsChain::SoxEffectsChain(
......@@ -134,6 +173,7 @@ SoxEffectsChain::SoxEffectsChain(
out_enc_(output_encoding),
in_sig_(),
interm_sig_(),
out_sig_(),
sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) {
if (!sec_) {
throw std::runtime_error("Failed to create effect chain.");
......@@ -184,6 +224,17 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
}
}
/// Adds the "output_file" effect, bound to `sf`, to the effects chain.
///
/// The output signal info is copied from `sf->signal` before the effect is
/// added. Throws std::runtime_error if sox_add_effect does not return
/// SOX_SUCCESS.
void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
  out_sig_ = sf->signal;

  SoxEffect effect(sox_create_effect(get_file_output_handler()));
  // Hand the file handle to the flow callback through the effect's private data.
  auto* priv = static_cast<FileOutputPriv*>(effect->priv);
  priv->sf = sf;

  const int status = sox_add_effect(sec_, effect, &interm_sig_, &out_sig_);
  if (status == SOX_SUCCESS) {
    return;
  }

  std::ostringstream message;
  message << "Failed to add effect: output " << sf->filename;
  throw std::runtime_error(message.str());
}
void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
const auto num_args = effect.size();
if (num_args == 0) {
......
......@@ -14,6 +14,7 @@ class SoxEffectsChain {
const sox_encodinginfo_t out_enc_;
sox_signalinfo_t in_sig_;
sox_signalinfo_t interm_sig_;
sox_signalinfo_t out_sig_;
sox_effects_chain_t* sec_;
public:
......@@ -29,6 +30,7 @@ class SoxEffectsChain {
void addInputTensor(torchaudio::sox_utils::TensorSignal* signal);
void addInputFile(sox_format_t* sf);
void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
void addOutputFile(sox_format_t* sf);
void addEffect(const std::vector<std::string> effect);
int64_t getOutputNumChannels();
int64_t getOutputSampleRate();
......
#include <sox.h>
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_effects_chain.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
......@@ -60,64 +62,21 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
"Invalid argument: num_frames must be -1 or greater than 0.");
}
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));
validate_input_file(sf);
const int64_t num_channels = sf->signal.channels;
const int64_t num_total_samples = sf->signal.length;
const int64_t sample_start = sf->signal.channels * frame_offset;
if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
throw std::runtime_error("Error reading audio file: offset past EOF.");
std::vector<std::vector<std::string>> effects;
if (num_frames != -1) {
std::ostringstream offset, frames;
offset << frame_offset << "s";
frames << "+" << num_frames << "s";
effects.emplace_back(
std::vector<std::string>{"trim", offset.str(), frames.str()});
} else if (frame_offset != 0) {
std::ostringstream offset;
offset << frame_offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", offset.str()});
}
const int64_t sample_end = [&]() {
if (num_frames == -1)
return num_total_samples;
const int64_t sample_end_ = num_channels * num_frames + sample_start;
if (num_total_samples < sample_end_) {
// For lossy encoding, it is difficult to predict exact size of buffer for
// reading the number of samples required.
// So we allocate buffer size of given `num_frames` and ask sox to read as
// much as possible. For lossless format, sox reads exact number of
// samples, but for lossy encoding, sox can end up reading less. (i.e.
// mp3) For the consistent behavior specification between lossy/lossless
// format, we allow users to provide `num_frames` value that exceeds #of
// available samples, and we adjust it here.
return num_total_samples;
}
return sample_end_;
}();
const int64_t max_samples = sample_end - sample_start;
// Read samples into buffer
std::vector<sox_sample_t> buffer;
buffer.reserve(max_samples);
const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
if (num_samples == 0) {
throw std::runtime_error(
"Error reading audio file: empty file or read operation failed.");
}
// NOTE: num_samples may be smaller than max_samples if the input
// format is compressed (i.e. mp3).
// Convert to Tensor
auto tensor = convert_to_tensor(
buffer.data(),
num_samples,
num_channels,
get_dtype(sf->encoding.encoding, sf->signal.precision),
normalize,
channels_first);
return c10::make_intrusive<TensorSignal>(
tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
return torchaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first);
}
void save_audio_file(
......@@ -125,7 +84,6 @@ void save_audio_file(
const c10::intrusive_ptr<TensorSignal>& signal,
const double compression) {
const auto tensor = signal->getTensor();
const auto channels_first = signal->getChannelsFirst();
validate_input_tensor(tensor);
......@@ -146,22 +104,12 @@ void save_audio_file(
throw std::runtime_error("Error saving audio file: failed to open file.");
}
auto tensor_ = tensor;
if (channels_first) {
tensor_ = tensor_.t();
}
const int64_t frames_per_chunk = 65536;
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
chunk = unnormalize_wav(chunk).contiguous();
const size_t numel = chunk.numel();
if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
throw std::runtime_error(
"Error saving audio file: failed to write the entier buffer.");
}
}
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype(), 0.),
/*output_encoding=*/sf->encoding);
chain.addInputTensor(signal.get());
chain.addOutputFile(sf);
chain.run();
}
} // namespace sox_io
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment