Unverified Commit e344e45f authored by moto's avatar moto Committed by GitHub
Browse files

Simplify C++ registration with TORCH_LIBRARY (#840)

parent 6aafbb6d
...@@ -5,86 +5,70 @@ ...@@ -5,86 +5,70 @@
#include <torchaudio/csrc/sox_io.h> #include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h> #include <torchaudio/csrc/sox_utils.h>
namespace torchaudio { TORCH_LIBRARY(torchaudio, m) {
namespace { //////////////////////////////////////////////////////////////////////////////
// sox_utils.h
//////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// sox_utils.h m.class_<torchaudio::sox_utils::TensorSignal>("TensorSignal")
////////////////////////////////////////////////////////////////////////////////
static auto registerTensorSignal =
torch::class_<sox_utils::TensorSignal>("torchaudio", "TensorSignal")
.def(torch::init<torch::Tensor, int64_t, bool>()) .def(torch::init<torch::Tensor, int64_t, bool>())
.def("get_tensor", &sox_utils::TensorSignal::getTensor) .def("get_tensor", &torchaudio::sox_utils::TensorSignal::getTensor)
.def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate) .def(
.def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst); "get_sample_rate",
&torchaudio::sox_utils::TensorSignal::getSampleRate)
static auto registerSetSoxOptions = .def(
torch::RegisterOperators() "get_channels_first",
.op("torchaudio::sox_utils_set_seed", &sox_utils::set_seed) &torchaudio::sox_utils::TensorSignal::getChannelsFirst);
.op("torchaudio::sox_utils_set_verbosity", &sox_utils::set_verbosity)
.op("torchaudio::sox_utils_set_use_threads",
&sox_utils::set_use_threads)
.op("torchaudio::sox_utils_set_buffer_size",
&sox_utils::set_buffer_size)
.op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects)
.op("torchaudio::sox_utils_list_read_formats",
&sox_utils::list_read_formats)
.op("torchaudio::sox_utils_list_write_formats",
&sox_utils::list_write_formats);
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
////////////////////////////////////////////////////////////////////////////////
static auto registerSignalInfo =
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info")
.catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>());
static auto registerLoadAudioFile = torch::RegisterOperators().op( m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
torch::RegisterOperators::options() m.def(
.schema( "torchaudio::sox_utils_set_verbosity",
"torchaudio::sox_io_load_audio_file(str path, int frame_offset, int num_frames, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal signal") &torchaudio::sox_utils::set_verbosity);
.catchAllKernel< m.def(
decltype(sox_io::load_audio_file), "torchaudio::sox_utils_set_use_threads",
&sox_io::load_audio_file>()); &torchaudio::sox_utils::set_use_threads);
m.def(
"torchaudio::sox_utils_set_buffer_size",
&torchaudio::sox_utils::set_buffer_size);
m.def(
"torchaudio::sox_utils_list_effects",
&torchaudio::sox_utils::list_effects);
m.def(
"torchaudio::sox_utils_list_read_formats",
&torchaudio::sox_utils::list_read_formats);
m.def(
"torchaudio::sox_utils_list_write_formats",
&torchaudio::sox_utils::list_write_formats);
static auto registerSaveAudioFile = torch::RegisterOperators().op( //////////////////////////////////////////////////////////////////////////////
torch::RegisterOperators::options() // sox_io.h
.schema( //////////////////////////////////////////////////////////////////////////////
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()") m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.catchAllKernel< .def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
decltype(sox_io::save_audio_file), .def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
&sox_io::save_audio_file>()); .def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames);
//////////////////////////////////////////////////////////////////////////////// m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
// sox_effects.h m.def(
//////////////////////////////////////////////////////////////////////////////// "torchaudio::sox_io_load_audio_file",
static auto registerSoxEffects = &torchaudio::sox_io::load_audio_file);
torch::RegisterOperators() m.def(
.op("torchaudio::sox_effects_initialize_sox_effects", "torchaudio::sox_io_save_audio_file",
&sox_effects::initialize_sox_effects) &torchaudio::sox_io::save_audio_file);
.op("torchaudio::sox_effects_shutdown_sox_effects",
&sox_effects::shutdown_sox_effects)
.op(torch::RegisterOperators::options()
.schema(
"torchaudio::sox_effects_apply_effects_tensor(__torch__.torch.classes.torchaudio.TensorSignal input_signal, str[][] effects) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal")
.catchAllKernel<
decltype(sox_effects::apply_effects_tensor),
&sox_effects::apply_effects_tensor>())
.op(torch::RegisterOperators::options()
.schema(
"torchaudio::sox_effects_apply_effects_file(str path, str[][] effects, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal")
.catchAllKernel<
decltype(sox_effects::apply_effects_file),
&sox_effects::apply_effects_file>());
} // namespace //////////////////////////////////////////////////////////////////////////////
} // namespace torchaudio // sox_effects.h
//////////////////////////////////////////////////////////////////////////////
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&torchaudio::sox_effects::initialize_sox_effects);
m.def(
"torchaudio::sox_effects_shutdown_sox_effects",
&torchaudio::sox_effects::shutdown_sox_effects);
m.def(
"torchaudio::sox_effects_apply_effects_tensor",
&torchaudio::sox_effects::apply_effects_tensor);
m.def(
"torchaudio::sox_effects_apply_effects_file",
&torchaudio::sox_effects::apply_effects_file);
}
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment