"tests/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "81f8ba7952712c80e260cdb80e3646caf0a94811"
Unverified Commit 1fe0a40c authored by moto's avatar moto Committed by GitHub
Browse files

Clean up handling of optional args in C++ with c10:optional (#1043)

parent fb3ef9ba
......@@ -91,8 +91,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
c10::intrusive_ptr<TensorSignal> apply_effects_file(
const std::string path,
std::vector<std::vector<std::string>> effects,
const bool normalize,
const bool channels_first) {
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first) {
// Open input file
SoxFormat sf(sox_open_read(
path.c_str(),
......@@ -121,16 +121,17 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
chain.run();
// Create tensor from buffer
bool channels_first_ = channels_first.value_or(true);
auto tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(),
dtype,
normalize,
channels_first);
normalize.value_or(true),
channels_first_);
return c10::make_intrusive<TensorSignal>(
tensor, chain.getOutputSampleRate(), channels_first);
tensor, chain.getOutputSampleRate(), channels_first_);
}
} // namespace sox_effects
......
......@@ -18,8 +18,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_tensor(
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file(
const std::string path,
std::vector<std::vector<std::string>> effects,
const bool normalize = true,
const bool channels_first = true);
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first);
} // namespace sox_effects
} // namespace torchaudio
......
......@@ -49,30 +49,32 @@ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path,
const int64_t frame_offset,
const int64_t num_frames,
const bool normalize,
const bool channels_first) {
if (frame_offset < 0) {
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
"Invalid argument: frame_offset must be non-negative.");
}
if (num_frames == 0 || num_frames < -1) {
const auto frames = num_frames.value_or(-1);
if (frames == 0 || frames < -1) {
throw std::runtime_error(
"Invalid argument: num_frames must be -1 or greater than 0.");
}
std::vector<std::vector<std::string>> effects;
if (num_frames != -1) {
std::ostringstream offset, frames;
offset << frame_offset << "s";
frames << "+" << num_frames << "s";
if (frames != -1) {
std::ostringstream os_offset, os_frames;
os_offset << offset << "s";
os_frames << "+" << frames << "s";
effects.emplace_back(
std::vector<std::string>{"trim", offset.str(), frames.str()});
} else if (frame_offset != 0) {
std::ostringstream offset;
offset << frame_offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", offset.str()});
std::vector<std::string>{"trim", os_offset.str(), os_frames.str()});
} else if (offset != 0) {
std::ostringstream os_offset;
os_offset << offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
}
return torchaudio::sox_effects::apply_effects_file(
......
......@@ -25,10 +25,10 @@ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path,
const int64_t frame_offset = 0,
const int64_t num_frames = -1,
const bool normalize = true,
const bool channels_first = true);
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first);
void save_audio_file(
const std::string& file_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment