Commit 71ddee16 authored by hwangjeff
Browse files

Make buffer size configurable in ffmpeg file object operations and set size in backend

parent 89e28623
...@@ -4,6 +4,7 @@ from typing import Optional, Tuple ...@@ -4,6 +4,7 @@ from typing import Optional, Tuple
import torch import torch
import torchaudio import torchaudio
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from torchaudio.utils.sox_utils import get_buffer_size
from .common import AudioMetaData from .common import AudioMetaData
...@@ -91,12 +92,13 @@ def info( ...@@ -91,12 +92,13 @@ def info(
# The previous libsox-based implementation required `format="mp3"` # The previous libsox-based implementation required `format="mp3"`
# because internally libsox does not auto-detect the format. # because internally libsox does not auto-detect the format.
# For the special BC for mp3, we handle mp3 differently. # For the special BC for mp3, we handle mp3 differently.
buffer_size = get_buffer_size()
if format == "mp3": if format == "mp3":
return _fallback_info_fileobj(filepath, format) return _fallback_info_fileobj(filepath, format, buffer_size)
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format) sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
if sinfo is not None: if sinfo is not None:
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format) return _fallback_info_fileobj(filepath, format, buffer_size)
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
if sinfo is not None: if sinfo is not None:
...@@ -210,14 +212,31 @@ def load( ...@@ -210,14 +212,31 @@ def load(
# The previous libsox-based implementation required `format="mp3"` # The previous libsox-based implementation required `format="mp3"`
# because internally libsox does not auto-detect the format. # because internally libsox does not auto-detect the format.
# For the special BC for mp3, we handle mp3 differently. # For the special BC for mp3, we handle mp3 differently.
buffer_size = get_buffer_size()
if format == "mp3": if format == "mp3":
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format) return _fallback_load_fileobj(
filepath,
frame_offset,
num_frames,
normalize,
channels_first,
format,
buffer_size,
)
ret = torchaudio._torchaudio.load_audio_fileobj( ret = torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format filepath, frame_offset, num_frames, normalize, channels_first, format
) )
if ret is not None: if ret is not None:
return ret return ret
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format) return _fallback_load_fileobj(
filepath,
frame_offset,
num_frames,
normalize,
channels_first,
format,
buffer_size,
)
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
ret = torch.ops.torchaudio.sox_io_load_audio_file( ret = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format filepath, frame_offset, num_frames, normalize, channels_first, format
...@@ -385,10 +404,24 @@ def save( ...@@ -385,10 +404,24 @@ def save(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if hasattr(filepath, "write"): if hasattr(filepath, "write"):
torchaudio._torchaudio.save_audio_fileobj( torchaudio._torchaudio.save_audio_fileobj(
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample filepath,
src,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample,
) )
return return
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file( torch.ops.torchaudio.sox_io_save_audio_file(
filepath, src, sample_rate, channels_first, compression, format, encoding, bits_per_sample filepath,
src,
sample_rate,
channels_first,
compression,
format,
encoding,
bits_per_sample,
) )
...@@ -36,8 +36,9 @@ def info_audio( ...@@ -36,8 +36,9 @@ def info_audio(
def info_audio_fileobj(
    src,
    format: Optional[str],
    buffer_size: int = 4096,
) -> AudioMetaData:
    """Return the audio metadata of ``src`` using the FFmpeg stream reader.

    Args:
        src: Source handed to ``StreamReaderFileObj``; presumably a
            file-like object exposing ``read`` -- confirm against callers.
        format: Optional format hint, forwarded unchanged to the reader.
        buffer_size: Forwarded unchanged as the reader's last argument.
            Defaults to 4096. Presumably the size in bytes of the
            intermediate read buffer -- confirm against the binding.

    Returns:
        The ``AudioMetaData`` produced by ``_info_audio`` for the reader.
    """
    # The third positional argument is passed as None, as the reader call
    # in this block has always done (NOTE(review): presumably the reader's
    # extra options -- confirm against the binding).
    return _info_audio(
        torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(
            src,
            format,
            None,
            buffer_size,
        )
    )
...@@ -110,6 +111,7 @@ def load_audio_fileobj( ...@@ -110,6 +111,7 @@ def load_audio_fileobj(
convert: bool = True, convert: bool = True,
channels_first: bool = True, channels_first: bool = True,
format: Optional[str] = None, format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, 4096) s = torchaudio._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
return _load_audio(s, frame_offset, num_frames, convert, channels_first) return _load_audio(s, frame_offset, num_frames, convert, channels_first)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment