Unverified commit 4b3e9052, authored by moto, committed by GitHub
Browse files

Remove frames_per_chunk argument from save (#780)

In #779, we plan to remove the `frames_per_chunk` parameter from the `save` function. Because it will take some time before #779 can land, this change removes the parameter first to reduce the conflicts caused by the interface change.
parent f6dc2f67
...@@ -87,7 +87,6 @@ def save( ...@@ -87,7 +87,6 @@ def save(
sample_rate: int, sample_rate: int,
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
frames_per_chunk: int = 65536,
): ):
"""Save audio data to file. """Save audio data to file.
...@@ -115,8 +114,6 @@ def save( ...@@ -115,8 +114,6 @@ def save(
``8`` is default and highest compression. ``8`` is default and highest compression.
- OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest - OGG/VORBIS: number from -1 to 10; -1 is the highest compression and lowest
quality. Default: ``3``. quality. Default: ``3``.
frames_per_chunk: The number of frames to process (convert to ``int32`` internally
then write to file) at a time.
""" """
if compression is None: if compression is None:
ext = str(filepath)[-3:].lower() ext = str(filepath)[-3:].lower()
...@@ -131,7 +128,7 @@ def save( ...@@ -131,7 +128,7 @@ def save(
else: else:
raise RuntimeError(f'Unsupported file type: "{ext}"') raise RuntimeError(f'Unsupported file type: "{ext}"')
signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first) signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first)
torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression, frames_per_chunk) torch.ops.torchaudio.sox_io_save_audio_file(filepath, signal, compression)
load_wav = load load_wav = load
...@@ -44,7 +44,7 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op( ...@@ -44,7 +44,7 @@ static auto registerLoadAudioFile = torch::RegisterOperators().op(
static auto registerSaveAudioFile = torch::RegisterOperators().op( static auto registerSaveAudioFile = torch::RegisterOperators().op(
torch::RegisterOperators::options() torch::RegisterOperators::options()
.schema( .schema(
"torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression, int frames_per_chunk) -> ()") "torchaudio::sox_io_save_audio_file(str path, __torch__.torch.classes.torchaudio.TensorSignal signal, float compression) -> ()")
.catchAllKernel< .catchAllKernel<
decltype(sox_io::save_audio_file), decltype(sox_io::save_audio_file),
&sox_io::save_audio_file>()); &sox_io::save_audio_file>());
......
...@@ -123,8 +123,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file( ...@@ -123,8 +123,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
void save_audio_file( void save_audio_file(
const std::string& file_name, const std::string& file_name,
const c10::intrusive_ptr<TensorSignal>& signal, const c10::intrusive_ptr<TensorSignal>& signal,
const double compression, const double compression) {
const int64_t frames_per_chunk) {
const auto tensor = signal->getTensor(); const auto tensor = signal->getTensor();
const auto sample_rate = signal->getSampleRate(); const auto sample_rate = signal->getSampleRate();
const auto channels_first = signal->getChannelsFirst(); const auto channels_first = signal->getChannelsFirst();
...@@ -154,6 +153,7 @@ void save_audio_file( ...@@ -154,6 +153,7 @@ void save_audio_file(
tensor_ = tensor_.t(); tensor_ = tensor_.t();
} }
const int64_t frames_per_chunk = 65536;
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) { for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()}); auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
chunk = unnormalize_wav(chunk).contiguous(); chunk = unnormalize_wav(chunk).contiguous();
......
...@@ -33,8 +33,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file( ...@@ -33,8 +33,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
void save_audio_file( void save_audio_file(
const std::string& file_name, const std::string& file_name,
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal, const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
const double compression = 0., const double compression = 0.);
const int64_t frames_per_chunk = 65536);
} // namespace sox_io } // namespace sox_io
} // namespace torchaudio } // namespace torchaudio
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment