"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "94566e6dd8b018726b215f70e818589ac9815830"
Unverified commit e7161acf, authored by Tim Loderhose, committed by GitHub
Browse files

Add pathlib.Path support to sox_io backend (#907)

parent 4cdd8cad
...@@ -24,11 +24,15 @@ def info(filepath: str) -> AudioMetaData: ...@@ -24,11 +24,15 @@ def info(filepath: str) -> AudioMetaData:
"""Get signal information of an audio file. """Get signal information of an audio file.
Args: Args:
filepath (str): Path to audio file filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is annotated as
``str`` for TorchScript compiler compatibility.
Returns: Returns:
AudioMetaData: meta data of the given audio. AudioMetaData: meta data of the given audio.
""" """
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels()) return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
...@@ -80,8 +84,9 @@ def load( ...@@ -80,8 +84,9 @@ def load(
``[-1.0, 1.0]``. ``[-1.0, 1.0]``.
Args: Args:
filepath (str): filepath (str or pathlib.Path):
Path to audio file Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
frame_offset (int): frame_offset (int):
Number of frames to skip before start reading data. Number of frames to skip before start reading data.
num_frames (int): num_frames (int):
...@@ -105,6 +110,8 @@ def load( ...@@ -105,6 +110,8 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``. ``[channel, time]`` else ``[time, channel]``.
""" """
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
signal = torch.ops.torchaudio.sox_io_load_audio_file( signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first) filepath, frame_offset, num_frames, normalize, channels_first)
return signal.get_tensor(), signal.get_sample_rate() return signal.get_tensor(), signal.get_sample_rate()
...@@ -140,7 +147,9 @@ def save( ...@@ -140,7 +147,9 @@ def save(
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc. and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args: Args:
filepath (str): Path to save file. filepath (str or pathlib.Path):
Path to save file. This function also handles ``pathlib.Path`` objects, but is annotated
as ``str`` for TorchScript compiler compatibility.
tensor (torch.Tensor): Audio data to save. must be 2D tensor. tensor (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate sample_rate (int): sampling rate
channels_first (bool): channels_first (bool):
...@@ -158,6 +167,8 @@ def save( ...@@ -158,6 +167,8 @@ def save(
See the detail at http://sox.sourceforge.net/soxformat.html. See the detail at http://sox.sourceforge.net/soxformat.html.
""" """
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if compression is None: if compression is None:
ext = str(filepath).split('.')[-1].lower() ext = str(filepath).split('.')[-1].lower()
if ext in ['wav', 'sph']: if ext in ['wav', 'sph']:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment