Unverified Commit e344e45f authored by moto's avatar moto Committed by GitHub
Browse files

Simplify C++ registration with TORCH_LIBRARY (#840)

parent 6aafbb6d
...@@ -5,86 +5,70 @@ ...@@ -5,86 +5,70 @@
#include <torchaudio/csrc/sox_io.h> #include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h> #include <torchaudio/csrc/sox_utils.h>
namespace torchaudio { TORCH_LIBRARY(torchaudio, m) {
namespace { //////////////////////////////////////////////////////////////////////////////
// sox_utils.h
//////////////////////////////////////////////////////////////////////////////
m.class_<torchaudio::sox_utils::TensorSignal>("TensorSignal")
.def(torch::init<torch::Tensor, int64_t, bool>())
.def("get_tensor", &torchaudio::sox_utils::TensorSignal::getTensor)
.def(
"get_sample_rate",
&torchaudio::sox_utils::TensorSignal::getSampleRate)
.def(
"get_channels_first",
&torchaudio::sox_utils::TensorSignal::getChannelsFirst);
//////////////////////////////////////////////////////////////////////////////// m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
// sox_utils.h m.def(
//////////////////////////////////////////////////////////////////////////////// "torchaudio::sox_utils_set_verbosity",
static auto registerTensorSignal = &torchaudio::sox_utils::set_verbosity);
torch::class_<sox_utils::TensorSignal>("torchaudio", "TensorSignal") m.def(
.def(torch::init<torch::Tensor, int64_t, bool>()) "torchaudio::sox_utils_set_use_threads",
.def("get_tensor", &sox_utils::TensorSignal::getTensor) &torchaudio::sox_utils::set_use_threads);
.def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate) m.def(
.def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst); "torchaudio::sox_utils_set_buffer_size",
&torchaudio::sox_utils::set_buffer_size);
m.def(
"torchaudio::sox_utils_list_effects",
&torchaudio::sox_utils::list_effects);
m.def(
"torchaudio::sox_utils_list_read_formats",
&torchaudio::sox_utils::list_read_formats);
m.def(
"torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats);
static auto registerSetSoxOptions = //////////////////////////////////////////////////////////////////////////////
torch::RegisterOperators() // sox_io.h
.op("torchaudio::sox_utils_set_seed", &sox_utils::set_seed) //////////////////////////////////////////////////////////////////////////////
.op("torchaudio::sox_utils_set_verbosity", &sox_utils::set_verbosity) m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.op("torchaudio::sox_utils_set_use_threads", .def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
&sox_utils::set_use_threads) .def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.op("torchaudio::sox_utils_set_buffer_size", .def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);
&sox_utils::set_buffer_size)
.op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects)
.op("torchaudio::sox_utils_list_read_formats",
&sox_utils::list_read_formats)
.op("torchaudio::sox_utils_list_write_formats",
&sox_utils::list_write_formats);
//////////////////////////////////////////////////////////////////////////////// m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
// sox_io.h m.def(
//////////////////////////////////////////////////////////////////////////////// "torchaudio::sox_io_load_audio_file",
static auto registerSignalInfo = &torchaudio::sox_io::load_audio_file);
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo") m.def(
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate) "torchaudio::sox_io_save_audio_file",
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels) &torchaudio::sox_io::save_audio_file);
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
static auto registerGetInfo = torch::RegisterOperators().op( //////////////////////////////////////////////////////////////////////////////
torch::RegisterOperators::options() // sox_effects.h
.schema( //////////////////////////////////////////////////////////////////////////////
"torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info") m.def(
.catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>()); "torchaudio::sox_effects_initialize_sox_effects",
&torchaudio::sox_effects::initialize_sox_effects);
static auto registerLoadAudioFile = torch::RegisterOperators().op( m.def(
torch::RegisterOperators::options() "torchaudio::sox_effects_shutdown_sox_effects",
.schema( &torchaudio::sox_effects::shutdown_sox_effects);
"torchaudio::sox_io_load_audio_file(str path, int frame_offset, int num_frames, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal signal") m.def(
.catchAllKernel< "torchaudio::sox_effects_apply_effects_tensor",
decltype(sox_io::load_audio_file), &torchaudio::sox_effects::apply_effects_tensor);
&sox_io::load_audio_file>()); m.def(
"torchaudio::sox_effects_apply_effects_file",
static auto registerSaveAudioFile = torch::RegisterOperators().op( &torchaudio::sox_effects::apply_effects_file);
torch::RegisterOperators::options() }
.schema(
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()")
.catchAllKernel<
decltype(sox_io::save_audio_file),
&sox_io::save_audio_file>());
////////////////////////////////////////////////////////////////////////////////
// sox_effects.h
////////////////////////////////////////////////////////////////////////////////
static auto registerSoxEffects =
torch::RegisterOperators()
.op("torchaudio::sox_effects_initialize_sox_effects",
&sox_effects::initialize_sox_effects)
.op("torchaudio::sox_effects_shutdown_sox_effects",
&sox_effects::shutdown_sox_effects)
.op(torch::RegisterOperators::options()
.schema(
"torchaudio::sox_effects_apply_effects_tensor(__torch__.torch.classes.torchaudio.TensorSignal input_signal, str[][] effects) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal")
.catchAllKernel<
decltype(sox_effects::apply_effects_tensor),
&sox_effects::apply_effects_tensor>())
.op(torch::RegisterOperators::options()
.schema(
"torchaudio::sox_effects_apply_effects_file(str path, str[][] effects, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal")
.catchAllKernel<
decltype(sox_effects::apply_effects_file),
&sox_effects::apply_effects_file>());
} // namespace
} // namespace torchaudio
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment