"...text-generation-inference.git" did not exist on "e6bb3ff81fd670ad2f54904676f8165367dd47f8"
Unverified Commit 1fe0a40c authored by moto's avatar moto Committed by GitHub
Browse files

Clean up handling of optional args in C++ with c10:optional (#1043)

parent fb3ef9ba
...@@ -91,8 +91,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor( ...@@ -91,8 +91,8 @@ c10::intrusive_ptr<TensorSignal> apply_effects_tensor(
c10::intrusive_ptr<TensorSignal> apply_effects_file( c10::intrusive_ptr<TensorSignal> apply_effects_file(
const std::string path, const std::string path,
std::vector<std::vector<std::string>> effects, std::vector<std::vector<std::string>> effects,
const bool normalize, c10::optional<bool>& normalize,
const bool channels_first) { c10::optional<bool>& channels_first) {
// Open input file // Open input file
SoxFormat sf(sox_open_read( SoxFormat sf(sox_open_read(
path.c_str(), path.c_str(),
...@@ -121,16 +121,17 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file( ...@@ -121,16 +121,17 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
chain.run(); chain.run();
// Create tensor from buffer // Create tensor from buffer
bool channels_first_ = channels_first.value_or(true);
auto tensor = convert_to_tensor( auto tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(), /*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(), /*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(), /*num_channels=*/chain.getOutputNumChannels(),
dtype, dtype,
normalize, normalize.value_or(true),
channels_first); channels_first_);
return c10::make_intrusive<TensorSignal>( return c10::make_intrusive<TensorSignal>(
tensor, chain.getOutputSampleRate(), channels_first); tensor, chain.getOutputSampleRate(), channels_first_);
} }
} // namespace sox_effects } // namespace sox_effects
......
...@@ -18,8 +18,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_tensor( ...@@ -18,8 +18,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_tensor(
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file( c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file(
const std::string path, const std::string path,
std::vector<std::vector<std::string>> effects, std::vector<std::vector<std::string>> effects,
const bool normalize = true, c10::optional<bool>& normalize,
const bool channels_first = true); c10::optional<bool>& channels_first);
} // namespace sox_effects } // namespace sox_effects
} // namespace torchaudio } // namespace torchaudio
......
...@@ -49,30 +49,32 @@ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) { ...@@ -49,30 +49,32 @@ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
c10::intrusive_ptr<TensorSignal> load_audio_file( c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path, const std::string& path,
const int64_t frame_offset, c10::optional<int64_t>& frame_offset,
const int64_t num_frames, c10::optional<int64_t>& num_frames,
const bool normalize, c10::optional<bool>& normalize,
const bool channels_first) { c10::optional<bool>& channels_first) {
if (frame_offset < 0) { const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error( throw std::runtime_error(
"Invalid argument: frame_offset must be non-negative."); "Invalid argument: frame_offset must be non-negative.");
} }
if (num_frames == 0 || num_frames < -1) { const auto frames = num_frames.value_or(-1);
if (frames == 0 || frames < -1) {
throw std::runtime_error( throw std::runtime_error(
"Invalid argument: num_frames must be -1 or greater than 0."); "Invalid argument: num_frames must be -1 or greater than 0.");
} }
std::vector<std::vector<std::string>> effects; std::vector<std::vector<std::string>> effects;
if (num_frames != -1) { if (frames != -1) {
std::ostringstream offset, frames; std::ostringstream os_offset, os_frames;
offset << frame_offset << "s"; os_offset << offset << "s";
frames << "+" << num_frames << "s"; os_frames << "+" << frames << "s";
effects.emplace_back( effects.emplace_back(
std::vector<std::string>{"trim", offset.str(), frames.str()}); std::vector<std::string>{"trim", os_offset.str(), os_frames.str()});
} else if (frame_offset != 0) { } else if (offset != 0) {
std::ostringstream offset; std::ostringstream os_offset;
offset << frame_offset << "s"; os_offset << offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", offset.str()}); effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
} }
return torchaudio::sox_effects::apply_effects_file( return torchaudio::sox_effects::apply_effects_file(
......
...@@ -25,10 +25,10 @@ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path); ...@@ -25,10 +25,10 @@ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file( c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path, const std::string& path,
const int64_t frame_offset = 0, c10::optional<int64_t>& frame_offset,
const int64_t num_frames = -1, c10::optional<int64_t>& num_frames,
const bool normalize = true, c10::optional<bool>& normalize,
const bool channels_first = true); c10::optional<bool>& channels_first);
void save_audio_file( void save_audio_file(
const std::string& file_name, const std::string& file_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment