"vscode:/vscode.git/clone" did not exist on "854ece55e7a5f5b6a815359d6f2e10892d0f40c9"
Commit 71ddee16 authored by hwangjeff's avatar hwangjeff
Browse files

Make buffer size configurable in ffmpeg file object operations and set size in backend

parent 89e28623
...@@ -4,6 +4,7 @@ from typing import Optional, Tuple ...@@ -4,6 +4,7 @@ from typing import Optional, Tuple
import torch import torch
import torchaudio import torchaudio
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.utils.sox_utils import get_buffer_size
from .common import AudioMetaData from .common import AudioMetaData
...@@ -91,12 +92,13 @@ def info( ...@@ -91,12 +92,13 @@ def info(
# The previous libsox-based implementation required `format="mp3"` # The previous libsox-based implementation required `format="mp3"`
# because internally libsox does not auto-detect the format. # because internally libsox does not auto-detect the format.
# For the special BC for mp3, we handle mp3 differently. # For the special BC for mp3, we handle mp3 differently.
buffer_size = get_buffer_size()
if format == "mp3": if format == "mp3":
return _fallback_info_fileobj(filepath, format) return _fallback_info_fileobj(filepath, format, buffer_size)
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format) sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
if sinfo is not None: if sinfo is not None:
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format) return _fallback_info_fileobj(filepath, format, buffer_size)
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
if sinfo is not None: if sinfo is not None:
...@@ -210,14 +212,31 @@ def load( ...@@ -210,14 +212,31 @@ def load(
# The previous libsox-based implementation required `format="mp3"` # The previous libsox-based implementation required `format="mp3"`
# because internally libsox does not auto-detect the format. # because internally libsox does not auto-detect the format.
# For the special BC for mp3, we handle mp3 differently. # For the special BC for mp3, we handle mp3 differently.
buffer_size = get_buffer_size()
if format == "mp3": if format == "mp3":
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format) return _fallback_load_fileobj(
filepath,
frame_offset,
num_frames,
normalize,
channels_first,
format,
buffer_size,
)
ret = torchaudio._torchaudio.load_audio_fileobj( ret = torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format filepath, frame_offset, num_frames, normalize, channels_first, format
) )
if ret is not None: if ret is not None:
return ret return ret
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format) return _fallback_load_fileobj(
filepath,
frame_offset,
num_frames,
normalize,
channels_first,
format,
buffer_size,
)
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
ret = torch.ops.torchaudio.sox_io_load_audio_file( ret = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format filepath, frame_offset, num_frames, normalize, channels_first, format
...@@ -385,10 +404,24 @@ def save( ...@@ -385,10 +404,24 @@ def save(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if hasattr(filepath, "write"): if hasattr(filepath, "write"):
torchaudio._torchaudio.save_audio_fileobj( torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample filepath,
src,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample,
) )
return return
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file( torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample filepath,
src,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample,
) )
...@@ -36,8 +36,9 @@ def info_audio( ...@@ -36,8 +36,9 @@ def info_audio(
def info_audio_fileobj( def info_audio_fileobj(
src, src,
format: Optional[str], format: Optional[str],
buffer_size: int = 4096,
) -> AudioMetaData: ) -> AudioMetaData:
s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
return _info_audio(s) return _info_audio(s)
...@@ -110,6 +111,7 @@ def load_audio_fileobj( ...@@ -110,6 +111,7 @@ def load_audio_fileobj(
convert: bool = True, convert: bool = True,
channels_first: bool = True, channels_first: bool = True,
format: Optional[str] = None, format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
return _load_audio(s, frame_offset, num_frames, convert, channels_first) return _load_audio(s, frame_offset, num_frames, convert, channels_first)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment