Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
e7161acf
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "94566e6dd8b018726b215f70e818589ac9815830"
Unverified
Commit
e7161acf
authored
Sep 18, 2020
by
Tim Loderhose
Committed by
GitHub
Sep 18, 2020
Browse files
Add pathlib.Path support to sox_io backend (#907)
parent
4cdd8cad
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
4 deletions
+15
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+15
-4
No files found.
torchaudio/backend/sox_io_backend.py
View file @
e7161acf
...
@@ -24,11 +24,15 @@ def info(filepath: str) -> AudioMetaData:
...
@@ -24,11 +24,15 @@ def info(filepath: str) -> AudioMetaData:
"""Get signal information of an audio file.
"""Get signal information of an audio file.
Args:
Args:
filepath (str): Path to audio file
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is annotated as
``str`` for TorchScript compiler compatibility.
Returns:
Returns:
AudioMetaData: meta data of the given audio.
AudioMetaData: meta data of the given audio.
"""
"""
# Cast to str in case type is `pathlib.Path`
filepath
=
str
(
filepath
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
())
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
())
...
@@ -80,8 +84,9 @@ def load(
...
@@ -80,8 +84,9 @@ def load(
``[-1.0, 1.0]``.
``[-1.0, 1.0]``.
Args:
Args:
filepath (str):
filepath (str or pathlib.Path):
Path to audio file
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
frame_offset (int):
frame_offset (int):
Number of frames to skip before start reading data.
Number of frames to skip before start reading data.
num_frames (int):
num_frames (int):
...
@@ -105,6 +110,8 @@ def load(
...
@@ -105,6 +110,8 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
``[channel, time]`` else ``[time, channel]``.
"""
"""
# Cast to str in case type is `pathlib.Path`
filepath
=
str
(
filepath
)
signal
=
torch
.
ops
.
torchaudio
.
sox_io_load_audio_file
(
signal
=
torch
.
ops
.
torchaudio
.
sox_io_load_audio_file
(
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
)
filepath
,
frame_offset
,
num_frames
,
normalize
,
channels_first
)
return
signal
.
get_tensor
(),
signal
.
get_sample_rate
()
return
signal
.
get_tensor
(),
signal
.
get_sample_rate
()
...
@@ -140,7 +147,9 @@ def save(
...
@@ -140,7 +147,9 @@ def save(
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
Args:
Args:
filepath (str): Path to save file.
filepath (str or pathlib.Path):
Path to save file. This function also handles ``pathlib.Path`` objects, but is annotated
as ``str`` for TorchScript compiler compatibility.
tensor (torch.Tensor): Audio data to save. must be 2D tensor.
tensor (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
sample_rate (int): sampling rate
channels_first (bool):
channels_first (bool):
...
@@ -158,6 +167,8 @@ def save(
...
@@ -158,6 +167,8 @@ def save(
See the detail at http://sox.sourceforge.net/soxformat.html.
See the detail at http://sox.sourceforge.net/soxformat.html.
"""
"""
# Cast to str in case type is `pathlib.Path`
filepath
=
str
(
filepath
)
if
compression
is
None
:
if
compression
is
None
:
ext
=
str
(
filepath
).
split
(
'.'
)[
-
1
].
lower
()
ext
=
str
(
filepath
).
split
(
'.'
)[
-
1
].
lower
()
if
ext
in
[
'wav'
,
'sph'
]:
if
ext
in
[
'wav'
,
'sph'
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment