Unverified Commit 4b3e9052 authored by moto's avatar moto Committed by GitHub
Browse files

Remove frames_per_chunk argument from save (#780)

In #779, we plan to remove `frames_per_chunk` parameter from `save` function, but it will take some time before we can land #779, so we go ahead and remove the parameter first to reduce the conflict caused by interface change.
parent f6dc2f67
......@@ -87,7 +87,6 @@ def save(
sample_rate: int,
channels_first: bool = True,
compression: Optional[float] = None,
frames_per_chunk: int = 65536,
):
"""Save audio data to file.
......@@ -115,8 +114,6 @@ def save(
``8`` is default and highest compression.
- OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest
quality. Default: ``3``.
frames_per_chunk: The number of frames to process (convert to ``int32`` internally
then write to file) at a time.
"""
if compression is None:
ext = str(filepath)[-3:].lower()
......@@ -131,7 +128,7 @@ def save(
else:
raise RuntimeError(f'Unsupported file type: "{ext}"')
signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first)
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk)
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)
load_wav = load
......@@ -44,7 +44,7 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op(
static auto registerSaveAudioFile = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression, int frames_per_chunk) -> ()")
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()")
.catchAllKernel<
decltype(sox_io::save_audio_file),
&sox_io::save_audio_file>());
......
......@@ -123,8 +123,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal,
const double compression,
const int64_t frames_per_chunk) {
const double compression) {
const auto tensor = signal->getTensor();
const auto sample_rate = signal->getSampleRate();
const auto channels_first = signal->getChannelsFirst();
......@@ -154,6 +153,7 @@ void save_audio_file(
tensor_ = tensor_.t();
}
const int64_t frames_per_chunk = 65536;
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
chunk = unnormalize_wav(chunk).contiguous();
......
......@@ -33,8 +33,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
void save_audio_file(
const std::string& file_name,
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
const double compression = 0.,
const int64_t frames_per_chunk = 65536);
const double compression = 0.);
} // namespace sox_io
} // namespace torchaudio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment