Unverified Commit 08f188b2 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Restructure C++ code to allow per file registration of custom ops (#1221)


Co-authored-by: default avatarPrabhat Roy <prabhatroy@fb.com>
parent 4608a5b2
...@@ -6,7 +6,7 @@ from torchaudio import ( ...@@ -6,7 +6,7 @@ from torchaudio import (
kaldi_io, kaldi_io,
utils, utils,
sox_effects, sox_effects,
transforms transforms,
) )
USE_SOUNDFILE_LEGACY_INTERFACE = None USE_SOUNDFILE_LEGACY_INTERFACE = None
......
...@@ -162,12 +162,9 @@ def load( ...@@ -162,12 +162,9 @@ def load(
if hasattr(filepath, 'read'): if hasattr(filepath, 'read'):
return torchaudio._torchaudio.load_audio_fileobj( return torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format) filepath, frame_offset, num_frames, normalize, channels_first, format)
signal = torch.ops.torchaudio.sox_io_load_audio_file( filepath = os.fspath(filepath)
os.fspath(filepath), frame_offset, num_frames, normalize, channels_first, format) return torch.ops.torchaudio.sox_io_load_audio_file(
return signal.get_tensor(), signal.get_sample_rate()
signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format) filepath, frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
@torch.jit.unused @torch.jit.unused
......
...@@ -50,24 +50,25 @@ void shutdown_sox_effects() { ...@@ -50,24 +50,25 @@ void shutdown_sox_effects() {
} }
} }
c10::intrusive_ptr<TensorSignal> apply_effects_tensor( std::tuple<torch::Tensor, int64_t> apply_effects_tensor(
const c10::intrusive_ptr<TensorSignal>& input_signal, torch::Tensor waveform,
std::vector<std::vector<std::string>> effects) { int64_t sample_rate,
auto in_tensor = input_signal->getTensor(); std::vector<std::vector<std::string>> effects,
validate_input_tensor(in_tensor); bool channels_first) {
validate_input_tensor(waveform);
// Create SoxEffectsChain // Create SoxEffectsChain
const auto dtype = in_tensor.dtype(); const auto dtype = waveform.dtype();
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", dtype), /*input_encoding=*/get_encodinginfo("wav", dtype),
/*output_encoding=*/get_encodinginfo("wav", dtype)); /*output_encoding=*/get_encodinginfo("wav", dtype));
// Prepare output buffer // Prepare output buffer
std::vector<sox_sample_t> out_buffer; std::vector<sox_sample_t> out_buffer;
out_buffer.reserve(in_tensor.numel()); out_buffer.reserve(waveform.numel());
// Build and run effects chain // Build and run effects chain
chain.addInputTensor(input_signal.get()); chain.addInputTensor(&waveform, sample_rate, channels_first);
for (const auto& effect : effects) { for (const auto& effect : effects) {
chain.addEffect(effect); chain.addEffect(effect);
} }
...@@ -75,7 +76,6 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor( ...@@ -75,7 +76,6 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
chain.run(); chain.run();
// Create tensor from buffer // Create tensor from buffer
const auto channels_first = input_signal->getChannelsFirst();
auto out_tensor = convert_to_tensor( auto out_tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(), /*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(), /*num_samples=*/out_buffer.size(),
...@@ -84,11 +84,11 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor( ...@@ -84,11 +84,11 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
/*noramlize=*/false, /*noramlize=*/false,
channels_first); channels_first);
return c10::make_intrusive<TensorSignal>( return std::tuple<torch::Tensor, int64_t>(
out_tensor, chain.getOutputSampleRate(), channels_first); out_tensor, chain.getOutputSampleRate());
} }
c10::intrusive_ptr<TensorSignal> apply_effects_file( std::tuple<torch::Tensor, int64_t> apply_effects_file(
const std::string path, const std::string path,
std::vector<std::vector<std::string>> effects, std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize, c10::optional<bool>& normalize,
...@@ -131,8 +131,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file( ...@@ -131,8 +131,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
normalize.value_or(true), normalize.value_or(true),
channels_first_); channels_first_);
return c10::make_intrusive<TensorSignal>( return std::tuple<torch::Tensor, int64_t>(
tensor, chain.getOutputSampleRate(), channels_first_); tensor, chain.getOutputSampleRate());
} }
#ifdef TORCH_API_INCLUDE_EXTENSION_H #ifdef TORCH_API_INCLUDE_EXTENSION_H
...@@ -238,5 +238,20 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj( ...@@ -238,5 +238,20 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&torchaudio::sox_effects::initialize_sox_effects);
m.def(
"torchaudio::sox_effects_shutdown_sox_effects",
&torchaudio::sox_effects::shutdown_sox_effects);
m.def(
"torchaudio::sox_effects_apply_effects_tensor",
&torchaudio::sox_effects::apply_effects_tensor);
m.def(
"torchaudio::sox_effects_apply_effects_file",
&torchaudio::sox_effects::apply_effects_file);
}
} // namespace sox_effects } // namespace sox_effects
} // namespace torchaudio } // namespace torchaudio
...@@ -15,11 +15,13 @@ void initialize_sox_effects(); ...@@ -15,11 +15,13 @@ void initialize_sox_effects();
void shutdown_sox_effects(); void shutdown_sox_effects();
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_tensor( std::tuple<torch::Tensor, int64_t> apply_effects_tensor(
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& input_signal, torch::Tensor waveform,
std::vector<std::vector<std::string>> effects); int64_t sample_rate,
std::vector<std::vector<std::string>> effects,
bool channels_first);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file( std::tuple<torch::Tensor, int64_t> apply_effects_file(
const std::string path, const std::string path,
std::vector<std::vector<std::string>> effects, std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize, c10::optional<bool>& normalize,
......
...@@ -36,12 +36,14 @@ struct SoxEffect { ...@@ -36,12 +36,14 @@ struct SoxEffect {
/// helper classes for passing the location of input tensor and output buffer /// helper classes for passing the location of input tensor and output buffer
/// ///
/// drain/flow callback functions require plaing C style function signature and /// drain/flow callback functions require plaing C style function signature and
/// the way to pass extra data is to attach data to sox_fffect_t::priv pointer. /// the way to pass extra data is to attach data to sox_effect_t::priv pointer.
/// The following structs will be assigned to sox_fffect_t::priv pointer which /// The following structs will be assigned to sox_effect_t::priv pointer which
/// gives sox_effect_t an access to input Tensor and output buffer object. /// gives sox_effect_t an access to input Tensor and output buffer object.
struct TensorInputPriv { struct TensorInputPriv {
size_t index; size_t index;
TensorSignal* signal; torch::Tensor* waveform;
int64_t sample_rate;
bool channels_first;
}; };
struct TensorOutputPriv { struct TensorOutputPriv {
std::vector<sox_sample_t>* buffer; std::vector<sox_sample_t>* buffer;
...@@ -55,8 +57,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -55,8 +57,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
// Retrieve the input Tensor and current index // Retrieve the input Tensor and current index
auto priv = static_cast<TensorInputPriv*>(effp->priv); auto priv = static_cast<TensorInputPriv*>(effp->priv);
auto index = priv->index; auto index = priv->index;
auto signal = priv->signal; auto tensor = *(priv->waveform);
auto tensor = signal->getTensor();
auto num_channels = effp->out_signal.channels; auto num_channels = effp->out_signal.channels;
// Adjust the number of samples to read // Adjust the number of samples to read
...@@ -71,7 +72,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { ...@@ -71,7 +72,7 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
const auto tensor_ = [&]() { const auto tensor_ = [&]() {
auto i_frame = index / num_channels; auto i_frame = index / num_channels;
auto num_frames = *osamp / num_channels; auto num_frames = *osamp / num_channels;
auto t = (signal->getChannelsFirst()) auto t = (priv->channels_first)
? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t() ? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t()
: tensor.index({Slice(i_frame, i_frame + num_frames), Slice()}); : tensor.index({Slice(i_frame, i_frame + num_frames), Slice()});
return unnormalize_wav(t.reshape({-1})).contiguous(); return unnormalize_wav(t.reshape({-1})).contiguous();
...@@ -193,13 +194,18 @@ void SoxEffectsChain::run() { ...@@ -193,13 +194,18 @@ void SoxEffectsChain::run() {
sox_flow_effects(sec_, NULL, NULL); sox_flow_effects(sec_, NULL, NULL);
} }
void SoxEffectsChain::addInputTensor(TensorSignal* signal) { void SoxEffectsChain::addInputTensor(
in_sig_ = get_signalinfo(signal, "wav"); torch::Tensor* waveform,
int64_t sample_rate,
bool channels_first) {
in_sig_ = get_signalinfo(waveform, sample_rate, "wav", channels_first);
interm_sig_ = in_sig_; interm_sig_ = in_sig_;
SoxEffect e(sox_create_effect(get_tensor_input_handler())); SoxEffect e(sox_create_effect(get_tensor_input_handler()));
auto priv = static_cast<TensorInputPriv*>(e->priv); auto priv = static_cast<TensorInputPriv*>(e->priv);
priv->signal = signal;
priv->index = 0; priv->index = 0;
priv->waveform = waveform;
priv->sample_rate = sample_rate;
priv->channels_first = channels_first;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error( throw std::runtime_error(
"Internal Error: Failed to add effect: input_tensor"); "Internal Error: Failed to add effect: input_tensor");
......
...@@ -30,7 +30,10 @@ class SoxEffectsChain { ...@@ -30,7 +30,10 @@ class SoxEffectsChain {
SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete; SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete;
~SoxEffectsChain(); ~SoxEffectsChain();
void run(); void run();
void addInputTensor(torchaudio::sox_utils::TensorSignal* signal); void addInputTensor(
torch::Tensor* waveform,
int64_t sample_rate,
bool channels_first);
void addInputFile(sox_format_t* sf); void addInputFile(sox_format_t* sf);
void addOutputBuffer(std::vector<sox_sample_t>* output_buffer); void addOutputBuffer(std::vector<sox_sample_t>* output_buffer);
void addOutputFile(sox_format_t* sf); void addOutputFile(sox_format_t* sf);
......
...@@ -131,7 +131,7 @@ std::vector<std::vector<std::string>> get_effects( ...@@ -131,7 +131,7 @@ std::vector<std::vector<std::string>> get_effects(
} // namespace } // namespace
c10::intrusive_ptr<TensorSignal> load_audio_file( std::tuple<torch::Tensor, int64_t> load_audio_file(
const std::string& path, const std::string& path,
c10::optional<int64_t>& frame_offset, c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames, c10::optional<int64_t>& num_frames,
...@@ -153,7 +153,6 @@ void save_audio_file( ...@@ -153,7 +153,6 @@ void save_audio_file(
c10::optional<std::string> dtype) { c10::optional<std::string> dtype) {
validate_input_tensor(tensor); validate_input_tensor(tensor);
auto signal = TensorSignal(tensor, sample_rate, channels_first);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) { if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error( throw std::runtime_error(
"dtype conversion only supported for float32 tensors"); "dtype conversion only supported for float32 tensors");
...@@ -174,7 +173,8 @@ void save_audio_file( ...@@ -174,7 +173,8 @@ void save_audio_file(
num_channels == 1, "amr-nb format only supports single channel audio."); num_channels == 1, "amr-nb format only supports single channel audio.");
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = get_signalinfo(&signal, filetype); const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
SoxFormat sf(sox_open_write( SoxFormat sf(sox_open_write(
...@@ -192,7 +192,7 @@ void save_audio_file( ...@@ -192,7 +192,7 @@ void save_audio_file(
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()),
/*output_encoding=*/sf->encoding); /*output_encoding=*/sf->encoding);
chain.addInputTensor(&signal); chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFile(sf); chain.addOutputFile(sf);
chain.run(); chain.run();
} }
...@@ -294,7 +294,6 @@ void save_audio_fileobj( ...@@ -294,7 +294,6 @@ void save_audio_fileobj(
c10::optional<std::string> dtype) { c10::optional<std::string> dtype) {
validate_input_tensor(tensor); validate_input_tensor(tensor);
auto signal = TensorSignal(tensor, sample_rate, channels_first);
if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) { if (tensor.dtype() != torch::kFloat32 && dtype.has_value()) {
throw std::runtime_error( throw std::runtime_error(
"dtype conversion only supported for float32 tensors"); "dtype conversion only supported for float32 tensors");
...@@ -312,7 +311,8 @@ void save_audio_fileobj( ...@@ -312,7 +311,8 @@ void save_audio_fileobj(
} }
tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16); tensor = (unnormalize_wav(tensor) / 65536).to(torch::kInt16);
} }
const auto signal_info = get_signalinfo(&signal, filetype); const auto signal_info =
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression); const auto encoding_info = get_encodinginfo(filetype, tgt_dtype, compression);
AutoReleaseBuffer buffer; AutoReleaseBuffer buffer;
...@@ -333,7 +333,7 @@ void save_audio_fileobj( ...@@ -333,7 +333,7 @@ void save_audio_fileobj(
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/get_encodinginfo("wav", tensor.dtype()), /*input_encoding=*/get_encodinginfo("wav", tensor.dtype()),
/*output_encoding=*/sf->encoding); /*output_encoding=*/sf->encoding);
chain.addInputTensor(&signal); chain.addInputTensor(&tensor, sample_rate, channels_first);
chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj); chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
chain.run(); chain.run();
...@@ -346,5 +346,24 @@ void save_audio_fileobj( ...@@ -346,5 +346,24 @@ void save_audio_fileobj(
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample)
.def("get_encoding", &torchaudio::sox_io::SignalInfo::getEncoding);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
"torchaudio::sox_io_load_audio_file",
&torchaudio::sox_io::load_audio_file);
m.def(
"torchaudio::sox_io_save_audio_file",
&torchaudio::sox_io::save_audio_file);
}
} // namespace sox_io } // namespace sox_io
} // namespace torchaudio } // namespace torchaudio
...@@ -35,7 +35,7 @@ c10::intrusive_ptr<SignalInfo> get_info_file( ...@@ -35,7 +35,7 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path, const std::string& path,
c10::optional<std::string>& format); c10::optional<std::string>& format);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file( std::tuple<torch::Tensor, int64_t> load_audio_file(
const std::string& path, const std::string& path,
c10::optional<int64_t>& frame_offset, c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames, c10::optional<int64_t>& num_frames,
......
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/utils.h>
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
//////////////////////////////////////////////////////////////////////////////
// sox_utils.h
//////////////////////////////////////////////////////////////////////////////
m.class_<torchaudio::sox_utils::TensorSignal>("TensorSignal")
.def(torch::init<torch::Tensor, int64_t, bool>())
.def("get_tensor", &torchaudio::sox_utils::TensorSignal::getTensor)
.def(
"get_sample_rate",
&torchaudio::sox_utils::TensorSignal::getSampleRate)
.def(
"get_channels_first",
&torchaudio::sox_utils::TensorSignal::getChannelsFirst);
m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
m.def(
"torchaudio::sox_utils_set_verbosity",
&torchaudio::sox_utils::set_verbosity);
m.def(
"torchaudio::sox_utils_set_use_threads",
&torchaudio::sox_utils::set_use_threads);
m.def(
"torchaudio::sox_utils_set_buffer_size",
&torchaudio::sox_utils::set_buffer_size);
m.def(
"torchaudio::sox_utils_list_effects",
&torchaudio::sox_utils::list_effects);
m.def(
"torchaudio::sox_utils_list_read_formats",
&torchaudio::sox_utils::list_read_formats);
m.def(
"torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats);
//////////////////////////////////////////////////////////////////////////////
// sox_io.h
//////////////////////////////////////////////////////////////////////////////
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample)
.def("get_encoding", &torchaudio::sox_io::SignalInfo::getEncoding);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
"torchaudio::sox_io_load_audio_file("
"str path,"
"int? frame_offset=None,"
"int? num_frames=None,"
"bool? normalize=True,"
"bool? channels_first=False,"
"str? format=None"
") -> __torch__.torch.classes.torchaudio.TensorSignal",
&torchaudio::sox_io::load_audio_file);
m.def(
"torchaudio::sox_io_save_audio_file",
&torchaudio::sox_io::save_audio_file);
//////////////////////////////////////////////////////////////////////////////
// sox_effects.h
//////////////////////////////////////////////////////////////////////////////
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&torchaudio::sox_effects::initialize_sox_effects);
m.def(
"torchaudio::sox_effects_shutdown_sox_effects",
&torchaudio::sox_effects::shutdown_sox_effects);
m.def(
"torchaudio::sox_effects_apply_effects_tensor",
&torchaudio::sox_effects::apply_effects_tensor);
m.def(
"torchaudio::sox_effects_apply_effects_file",
&torchaudio::sox_effects::apply_effects_file);
}
...@@ -61,24 +61,6 @@ std::vector<std::string> list_read_formats() { ...@@ -61,24 +61,6 @@ std::vector<std::string> list_read_formats() {
return formats; return formats;
} }
TensorSignal::TensorSignal(
torch::Tensor tensor_,
int64_t sample_rate_,
bool channels_first_)
: tensor(tensor_),
sample_rate(sample_rate_),
channels_first(channels_first_){};
torch::Tensor TensorSignal::getTensor() const {
return tensor;
}
int64_t TensorSignal::getSampleRate() const {
return sample_rate;
}
bool TensorSignal::getChannelsFirst() const {
return channels_first;
}
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {} SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
SoxFormat::~SoxFormat() { SoxFormat::~SoxFormat() {
close(); close();
...@@ -297,15 +279,16 @@ unsigned get_precision( ...@@ -297,15 +279,16 @@ unsigned get_precision(
} }
sox_signalinfo_t get_signalinfo( sox_signalinfo_t get_signalinfo(
const TensorSignal* signal, const torch::Tensor* waveform,
const std::string filetype) { const int64_t sample_rate,
auto tensor = signal->getTensor(); const std::string filetype,
const bool channels_first) {
return sox_signalinfo_t{ return sox_signalinfo_t{
/*rate=*/static_cast<sox_rate_t>(signal->getSampleRate()), /*rate=*/static_cast<sox_rate_t>(sample_rate),
/*channels=*/ /*channels=*/
static_cast<unsigned>(tensor.size(signal->getChannelsFirst() ? 0 : 1)), static_cast<unsigned>(waveform->size(channels_first ? 0 : 1)),
/*precision=*/get_precision(filetype, tensor.dtype()), /*precision=*/get_precision(filetype, waveform->dtype()),
/*length=*/static_cast<uint64_t>(tensor.numel())}; /*length=*/static_cast<uint64_t>(waveform->numel())};
} }
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_encodinginfo(
...@@ -364,5 +347,27 @@ uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) { ...@@ -364,5 +347,27 @@ uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) {
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
m.def(
"torchaudio::sox_utils_set_verbosity",
&torchaudio::sox_utils::set_verbosity);
m.def(
"torchaudio::sox_utils_set_use_threads",
&torchaudio::sox_utils::set_use_threads);
m.def(
"torchaudio::sox_utils_set_buffer_size",
&torchaudio::sox_utils::set_buffer_size);
m.def(
"torchaudio::sox_utils_list_effects",
&torchaudio::sox_utils::list_effects);
m.def(
"torchaudio::sox_utils_list_read_formats",
&torchaudio::sox_utils::list_read_formats);
m.def(
"torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats);
}
} // namespace sox_utils } // namespace sox_utils
} // namespace torchaudio } // namespace torchaudio
...@@ -30,23 +30,6 @@ std::vector<std::string> list_read_formats(); ...@@ -30,23 +30,6 @@ std::vector<std::string> list_read_formats();
std::vector<std::string> list_write_formats(); std::vector<std::string> list_write_formats();
/// Class for exchanging signal infomation (tensor + meta data) between
/// C++ and Python for read/write operation.
struct TensorSignal : torch::CustomClassHolder {
torch::Tensor tensor;
int64_t sample_rate;
bool channels_first;
TensorSignal(
torch::Tensor tensor_,
int64_t sample_rate_,
bool channels_first_);
torch::Tensor getTensor() const;
int64_t getSampleRate() const;
bool getChannelsFirst() const;
};
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Utilities for sox_io / sox_effects implementations // Utilities for sox_io / sox_effects implementations
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -120,8 +103,10 @@ const std::string get_filetype(const std::string path); ...@@ -120,8 +103,10 @@ const std::string get_filetype(const std::string path);
/// Get sox_signalinfo_t for passing a torch::Tensor object. /// Get sox_signalinfo_t for passing a torch::Tensor object.
sox_signalinfo_t get_signalinfo( sox_signalinfo_t get_signalinfo(
const TensorSignal* signal, const torch::Tensor* waveform,
const std::string filetype); const int64_t sample_rate,
const std::string filetype,
const bool channels_first);
/// Get sox_encofinginfo_t for saving audoi file /// Get sox_encofinginfo_t for saving audoi file
sox_encodinginfo_t get_encodinginfo( sox_encodinginfo_t get_encodinginfo(
......
...@@ -63,7 +63,7 @@ def apply_effects_tensor( ...@@ -63,7 +63,7 @@ def apply_effects_tensor(
Note: Note:
This function works in the way very similar to ``sox`` command, however there are slight This function works in the way very similar to ``sox`` command, however there are slight
differences. For example, ``sox`` commnad adds certain effects automatically (such as differences. For example, ``sox`` command adds certain effects automatically (such as
``rate`` effect after ``speed`` and ``pitch`` and other effects), but this function does ``rate`` effect after ``speed`` and ``pitch`` and other effects), but this function does
only applies the given effects. (Therefore, to actually apply ``speed`` effect, you also only applies the given effects. (Therefore, to actually apply ``speed`` effect, you also
need to give ``rate`` effect with desired sampling rate.) need to give ``rate`` effect with desired sampling rate.)
...@@ -149,9 +149,8 @@ def apply_effects_tensor( ...@@ -149,9 +149,8 @@ def apply_effects_tensor(
>>> waveform, sample_rate = transform(waveform, input_sample_rate) >>> waveform, sample_rate = transform(waveform, input_sample_rate)
>>> assert sample_rate == 8000 >>> assert sample_rate == 8000
""" """
in_signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first) return torch.ops.torchaudio.sox_effects_apply_effects_tensor(
out_signal = torch.ops.torchaudio.sox_effects_apply_effects_tensor(in_signal, effects) tensor, sample_rate, effects, channels_first)
return out_signal.get_tensor(), out_signal.get_sample_rate()
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
...@@ -268,6 +267,5 @@ def apply_effects_file( ...@@ -268,6 +267,5 @@ def apply_effects_file(
return torchaudio._torchaudio.apply_effects_fileobj( return torchaudio._torchaudio.apply_effects_fileobj(
path, effects, normalize, channels_first, format) path, effects, normalize, channels_first, format)
path = os.fspath(path) path = os.fspath(path)
signal = torch.ops.torchaudio.sox_effects_apply_effects_file( return torch.ops.torchaudio.sox_effects_apply_effects_file(
path, effects, normalize, channels_first, format) path, effects, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment