Unverified Commit e7161acf authored by Tim Loderhose's avatar Tim Loderhose Committed by GitHub
Browse files

Add pathlib.Path support to sox_io backend (#907)

parent 4cdd8cad
......@@ -24,11 +24,15 @@ def info(filepath: str) -> AudioMetaData:
"""Get signal information of an audio file.
Args:
filepath (str): Path to audio file
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is annotated as
``str`` for TorchScript compiler compatibility.
Returns:
AudioMetaData: meta data of the given audio.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
......@@ -80,8 +84,9 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str):
Path to audio file
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
......@@ -105,6 +110,8 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first)
return signal.get_tensor(), signal.get_sample_rate()
......@@ -140,7 +147,9 @@ def save(
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args:
filepath (str): Path to save file.
filepath (str or pathlib.Path):
Path to save file. This function also handles ``pathlib.Path`` objects, but is annotated
as ``str`` for TorchScript compiler compatibility.
tensor (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
channels_first (bool):
......@@ -158,6 +167,8 @@ def save(
See the detail at http://sox.sourceforge.net/soxformat.html.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if compression is None:
ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph']:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment