Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
99ed7183
Unverified
Commit
99ed7183
authored
Jan 25, 2021
by
Nicolas Hug
Committed by
GitHub
Jan 25, 2021
Browse files
Add bits_per_sample to info (#1177)
parent
27031755
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
127 additions
and
20 deletions
+127
-20
test/torchaudio_unittest/soundfile_backend/common.py
test/torchaudio_unittest/soundfile_backend/common.py
+1
-1
test/torchaudio_unittest/soundfile_backend/info_test.py
test/torchaudio_unittest/soundfile_backend/info_test.py
+39
-7
test/torchaudio_unittest/sox_io_backend/info_test.py
test/torchaudio_unittest/sox_io_backend/info_test.py
+17
-4
torchaudio/backend/_soundfile_backend.py
torchaudio/backend/_soundfile_backend.py
+48
-1
torchaudio/backend/common.py
torchaudio/backend/common.py
+4
-1
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+2
-1
torchaudio/csrc/sox/io.cpp
torchaudio/csrc/sox/io.cpp
+10
-3
torchaudio/csrc/sox/io.h
torchaudio/csrc/sox/io.h
+4
-1
torchaudio/csrc/sox/register.cpp
torchaudio/csrc/sox/register.cpp
+2
-1
No files found.
test/torchaudio_unittest/soundfile_backend/common.py
View file @
99ed7183
...
@@ -26,7 +26,7 @@ def skipIfFormatNotSupported(fmt):
...
@@ -26,7 +26,7 @@ def skipIfFormatNotSupported(fmt):
import
soundfile
import
soundfile
fmts
=
soundfile
.
available_formats
()
fmts
=
soundfile
.
available_formats
()
return
skipIf
(
fmt
not
in
fmts
,
f
'"
{
fmt
}
" is not supported by sondfile'
)
return
skipIf
(
fmt
not
in
fmts
,
f
'"
{
fmt
}
" is not supported by so
u
ndfile'
)
return
skipIf
(
True
,
'"soundfile" not available.'
)
return
skipIf
(
True
,
'"soundfile" not available.'
)
...
...
test/torchaudio_unittest/soundfile_backend/info_test.py
View file @
99ed7183
from
unittest.mock
import
patch
import
warnings
import
torch
import
torch
from
torchaudio.backend
import
_soundfile_backend
as
soundfile_backend
from
torchaudio.backend
import
_soundfile_backend
as
soundfile_backend
from
torchaudio._internal
import
module_utils
as
_mod_utils
from
torchaudio._internal
import
module_utils
as
_mod_utils
...
@@ -18,10 +21,11 @@ if _mod_utils.is_module_available("soundfile"):
...
@@ -18,10 +21,11 @@ if _mod_utils.is_module_available("soundfile"):
@
skipIfNoModule
(
"soundfile"
)
@
skipIfNoModule
(
"soundfile"
)
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
class
TestInfo
(
TempDirMixin
,
PytorchTestCase
):
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
1
,
2
],
[
(
"float32"
,
32
),
(
"int32"
,
32
),
(
"int16"
,
16
),
(
"uint8"
,
8
)
],
[
8000
,
16000
],
[
1
,
2
],
)
)
def
test_wav
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav
(
self
,
dtype
_and_bit_depth
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file correctly"""
"""`soundfile_backend.info` can check wav file correctly"""
dtype
,
bits_per_sample
=
dtype_and_bit_depth
duration
=
1
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
data
=
get_wav_data
(
...
@@ -32,12 +36,14 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -32,12 +36,14 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
@
parameterize
(
@
parameterize
(
[
"float32"
,
"int32"
,
"int16"
,
"uint8"
],
[
8000
,
16000
],
[
4
,
8
,
16
,
3
2
],
[
(
"float32"
,
32
),
(
"int32"
,
32
),
(
"int16"
,
16
),
(
"uint8"
,
8
)
],
[
8000
,
16000
],
[
1
,
2
],
)
)
def
test_wav_multiple_channels
(
self
,
dtype
,
sample_rate
,
num_channels
):
def
test_wav_multiple_channels
(
self
,
dtype
_and_bit_depth
,
sample_rate
,
num_channels
):
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
"""`soundfile_backend.info` can check wav file with channels more than 2 correctly"""
dtype
,
bits_per_sample
=
dtype_and_bit_depth
duration
=
1
duration
=
1
path
=
self
.
get_temp_path
(
"data.wav"
)
path
=
self
.
get_temp_path
(
"data.wav"
)
data
=
get_wav_data
(
data
=
get_wav_data
(
...
@@ -48,6 +54,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -48,6 +54,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
skipIfFormatNotSupported
(
"FLAC"
)
@
skipIfFormatNotSupported
(
"FLAC"
)
...
@@ -63,6 +70,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -63,6 +70,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
num_frames
assert
info
.
num_frames
==
num_frames
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
16
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
skipIfFormatNotSupported
(
"OGG"
)
@
skipIfFormatNotSupported
(
"OGG"
)
...
@@ -78,18 +86,42 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -78,18 +86,42 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
@
parameterize
([
8000
,
16000
],
[
1
,
2
])
@
parameterize
([
8000
,
16000
],
[
1
,
2
]
,
[(
'PCM_24'
,
24
),
(
'PCM_32'
,
32
)]
)
@
skipIfFormatNotSupported
(
"NIST"
)
@
skipIfFormatNotSupported
(
"NIST"
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
def
test_sphere
(
self
,
sample_rate
,
num_channels
,
subtype_and_bit_depth
):
"""`soundfile_backend.info` can check sph file correctly"""
"""`soundfile_backend.info` can check sph file correctly"""
duration
=
1
duration
=
1
num_frames
=
sample_rate
*
duration
num_frames
=
sample_rate
*
duration
data
=
torch
.
randn
(
num_frames
,
num_channels
).
numpy
()
data
=
torch
.
randn
(
num_frames
,
num_channels
).
numpy
()
path
=
self
.
get_temp_path
(
"data.nist"
)
path
=
self
.
get_temp_path
(
"data.nist"
)
soundfile
.
write
(
path
,
data
,
sample_rate
)
subtype
,
bits_per_sample
=
subtype_and_bit_depth
soundfile
.
write
(
path
,
data
,
sample_rate
,
subtype
=
subtype
)
info
=
soundfile_backend
.
info
(
path
)
info
=
soundfile_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
def
test_unknown_subtype_warning
(
self
):
"""soundfile_backend.info issues a warning when the subtype is unknown
This will happen if a new subtype is supported in SoundFile: the _SUBTYPE_TO_BITS_PER_SAMPLE
dict should be updated.
"""
def
_mock_info_func
(
_
):
class
MockSoundFileInfo
:
samplerate
=
8000
frames
=
356
channels
=
2
subtype
=
'UNSEEN_SUBTYPE'
return
MockSoundFileInfo
()
with
patch
(
"soundfile.info"
,
_mock_info_func
):
with
warnings
.
catch_warnings
(
record
=
True
)
as
w
:
info
=
soundfile_backend
.
info
(
"foo"
)
assert
len
(
w
)
==
1
assert
"UNSEEN_SUBTYPE subtype is unknown to TorchAudio"
in
str
(
w
[
-
1
].
message
)
assert
info
.
bits_per_sample
==
0
test/torchaudio_unittest/sox_io_backend/info_test.py
View file @
99ed7183
...
@@ -36,6 +36,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -36,6 +36,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
@@ -52,6 +53,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -52,6 +53,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
sox_utils
.
get_bit_depth
(
dtype
)
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -71,6 +73,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -71,6 +73,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
# mp3 does not preserve the number of samples
# mp3 does not preserve the number of samples
# assert info.num_frames == sample_rate * duration
# assert info.num_frames == sample_rate * duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bits_per_sample is irrelevant for compressed formats
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -89,6 +92,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -89,6 +92,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
24
# FLAC standard
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -107,20 +111,23 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -107,20 +111,23 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bits_per_sample is irrelevant for compressed formats
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
[
1
,
2
],
[
1
,
2
],
[
16
,
32
],
)),
name_func
=
name_func
)
)),
name_func
=
name_func
)
def
test_sphere
(
self
,
sample_rate
,
num_channels
):
def
test_sphere
(
self
,
sample_rate
,
num_channels
,
bits_per_sample
):
"""`sox_io_backend.info` can check sph file correctly"""
"""`sox_io_backend.info` can check sph file correctly"""
duration
=
1
duration
=
1
path
=
self
.
get_temp_path
(
'data.sph'
)
path
=
self
.
get_temp_path
(
'data.sph'
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
duration
=
duration
)
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
duration
=
duration
,
bit_depth
=
bits_per_sample
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
@@ -131,13 +138,15 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -131,13 +138,15 @@ class TestInfo(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.info` can check amb file correctly"""
"""`sox_io_backend.info` can check amb file correctly"""
duration
=
1
duration
=
1
path
=
self
.
get_temp_path
(
'data.amb'
)
path
=
self
.
get_temp_path
(
'data.amb'
)
bits_per_sample
=
sox_utils
.
get_bit_depth
(
dtype
)
sox_utils
.
gen_audio_file
(
sox_utils
.
gen_audio_file
(
path
,
sample_rate
,
num_channels
,
path
,
sample_rate
,
num_channels
,
bit_depth
=
sox_utils
.
get_bit_depth
(
dtype
)
,
duration
=
duration
)
bit_depth
=
bits_per_sample
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
bits_per_sample
def
test_amr_nb
(
self
):
def
test_amr_nb
(
self
):
"""`sox_io_backend.info` can check amr-nb file correctly"""
"""`sox_io_backend.info` can check amr-nb file correctly"""
...
@@ -146,11 +155,13 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -146,11 +155,13 @@ class TestInfo(TempDirMixin, PytorchTestCase):
sample_rate
=
8000
sample_rate
=
8000
path
=
self
.
get_temp_path
(
'data.amr-nb'
)
path
=
self
.
get_temp_path
(
'data.amr-nb'
)
sox_utils
.
gen_audio_file
(
sox_utils
.
gen_audio_file
(
path
,
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
bit_depth
=
16
,
duration
=
duration
)
path
,
sample_rate
=
sample_rate
,
num_channels
=
num_channels
,
bit_depth
=
16
,
duration
=
duration
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
sample_rate
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -167,6 +178,7 @@ class TestInfoOpus(PytorchTestCase):
...
@@ -167,6 +178,7 @@ class TestInfoOpus(PytorchTestCase):
assert
info
.
sample_rate
==
48000
assert
info
.
sample_rate
==
48000
assert
info
.
num_frames
==
32768
assert
info
.
num_frames
==
32768
assert
info
.
num_channels
==
num_channels
assert
info
.
num_channels
==
num_channels
assert
info
.
bits_per_sample
==
0
# bits_per_sample is irrelevant for compressed formats
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -184,3 +196,4 @@ class TestLoadWithoutExtension(PytorchTestCase):
...
@@ -184,3 +196,4 @@ class TestLoadWithoutExtension(PytorchTestCase):
path
=
get_asset_path
(
"mp3_without_ext"
)
path
=
get_asset_path
(
"mp3_without_ext"
)
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
sinfo
=
sox_io_backend
.
info
(
path
,
format
=
"mp3"
)
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
sample_rate
==
16000
assert
sinfo
.
bits_per_sample
==
0
# bits_per_sample is irrelevant for compressed formats
torchaudio/backend/_soundfile_backend.py
View file @
99ed7183
...
@@ -11,6 +11,45 @@ if _mod_utils.is_module_available("soundfile"):
...
@@ -11,6 +11,45 @@ if _mod_utils.is_module_available("soundfile"):
import
soundfile
import
soundfile
# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristic, and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
# For ADPCM (and G72X) subtypes, it's hard to infer the bit depth because it's not part of the standard:
# According to https://en.wikipedia.org/wiki/Adaptive_differential_pulse-code_modulation#In_telephony,
# the default seems to be 8 bits but it can be compressed further to 4 bits.
# The dict is inspired from
# https://github.com/bastibe/python-soundfile/blob/744efb4b01abc72498a96b09115b42a4cabd85e4/soundfile.py#L66-L94
_SUBTYPE_TO_BITS_PER_SAMPLE
=
{
'PCM_S8'
:
8
,
# Signed 8 bit data
'PCM_16'
:
16
,
# Signed 16 bit data
'PCM_24'
:
24
,
# Signed 24 bit data
'PCM_32'
:
32
,
# Signed 32 bit data
'PCM_U8'
:
8
,
# Unsigned 8 bit data (WAV and RAW only)
'FLOAT'
:
32
,
# 32 bit float data
'DOUBLE'
:
64
,
# 64 bit float data
'ULAW'
:
8
,
# U-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'ALAW'
:
8
,
# A-Law encoded. See https://en.wikipedia.org/wiki/G.711#Types
'IMA_ADPCM'
:
0
,
# IMA ADPCM.
'MS_ADPCM'
:
0
,
# Microsoft ADPCM.
'GSM610'
:
0
,
# GSM 6.10 encoding. (Wikipedia says 1.625 bit depth?? https://en.wikipedia.org/wiki/Full_Rate)
'VOX_ADPCM'
:
0
,
# OKI / Dialogix ADPCM
'G721_32'
:
0
,
# 32kbs G721 ADPCM encoding.
'G723_24'
:
0
,
# 24kbs G723 ADPCM encoding.
'G723_40'
:
0
,
# 40kbs G723 ADPCM encoding.
'DWVW_12'
:
12
,
# 12 bit Delta Width Variable Word encoding.
'DWVW_16'
:
16
,
# 16 bit Delta Width Variable Word encoding.
'DWVW_24'
:
24
,
# 24 bit Delta Width Variable Word encoding.
'DWVW_N'
:
0
,
# N bit Delta Width Variable Word encoding.
'DPCM_8'
:
8
,
# 8 bit differential PCM (XI only)
'DPCM_16'
:
16
,
# 16 bit differential PCM (XI only)
'VORBIS'
:
0
,
# Xiph Vorbis encoding. (lossy)
'ALAC_16'
:
16
,
# Apple Lossless Audio Codec (16 bit).
'ALAC_20'
:
20
,
# Apple Lossless Audio Codec (20 bit).
'ALAC_24'
:
24
,
# Apple Lossless Audio Codec (24 bit).
'ALAC_32'
:
32
,
# Apple Lossless Audio Codec (32 bit).
}
@
_mod_utils
.
requires_module
(
"soundfile"
)
@
_mod_utils
.
requires_module
(
"soundfile"
)
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioMetaData
:
def
info
(
filepath
:
str
,
format
:
Optional
[
str
]
=
None
)
->
AudioMetaData
:
"""Get signal information of an audio file.
"""Get signal information of an audio file.
...
@@ -27,7 +66,15 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
...
@@ -27,7 +66,15 @@ def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
AudioMetaData: meta data of the given audio.
AudioMetaData: meta data of the given audio.
"""
"""
sinfo
=
soundfile
.
info
(
filepath
)
sinfo
=
soundfile
.
info
(
filepath
)
return
AudioMetaData
(
sinfo
.
samplerate
,
sinfo
.
frames
,
sinfo
.
channels
)
if
sinfo
.
subtype
not
in
_SUBTYPE_TO_BITS_PER_SAMPLE
:
warnings
.
warn
(
f
"The
{
sinfo
.
subtype
}
subtype is unknown to TorchAudio. As a result, the bits_per_sample "
"attribute will be set to 0. If you are seeing this warning, please "
"report by opening an issue on github (after checking for existing/closed ones). "
"You may otherwise ignore this warning."
)
bits_per_sample
=
_SUBTYPE_TO_BITS_PER_SAMPLE
.
get
(
sinfo
.
subtype
,
0
)
return
AudioMetaData
(
sinfo
.
samplerate
,
sinfo
.
frames
,
sinfo
.
channels
,
bits_per_sample
=
bits_per_sample
)
_SUBTYPE2DTYPE
=
{
_SUBTYPE2DTYPE
=
{
...
...
torchaudio/backend/common.py
View file @
99ed7183
...
@@ -12,11 +12,14 @@ class AudioMetaData:
...
@@ -12,11 +12,14 @@ class AudioMetaData:
:ivar int sample_rate: Sample rate
:ivar int sample_rate: Sample rate
:ivar int num_frames: The number of frames
:ivar int num_frames: The number of frames
:ivar int num_channels: The number of channels
:ivar int num_channels: The number of channels
:ivar int bits_per_sample: The number of bits per sample. This is 0 for lossy formats,
or when it cannot be accurately inferred.
"""
"""
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
):
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
,
bits_per_sample
:
int
):
self
.
sample_rate
=
sample_rate
self
.
sample_rate
=
sample_rate
self
.
num_frames
=
num_frames
self
.
num_frames
=
num_frames
self
.
num_channels
=
num_channels
self
.
num_channels
=
num_channels
self
.
bits_per_sample
=
bits_per_sample
@
_mod_utils
.
deprecated
(
'Please migrate to `AudioMetaData`.'
,
'0.9.0'
)
@
_mod_utils
.
deprecated
(
'Please migrate to `AudioMetaData`.'
,
'0.9.0'
)
...
...
torchaudio/backend/sox_io_backend.py
View file @
99ed7183
...
@@ -32,7 +32,8 @@ def info(
...
@@ -32,7 +32,8 @@ def info(
# Cast to str in case type is `pathlib.Path`
# Cast to str in case type is `pathlib.Path`
filepath
=
str
(
filepath
)
filepath
=
str
(
filepath
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
,
format
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
,
format
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
())
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
(),
sinfo
.
get_bits_per_sample
())
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/sox/io.cpp
View file @
99ed7183
...
@@ -13,10 +13,12 @@ namespace sox_io {
...
@@ -13,10 +13,12 @@ namespace sox_io {
SignalInfo
::
SignalInfo
(
SignalInfo
::
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
)
const
int64_t
num_frames_
,
const
int64_t
bits_per_sample_
)
:
sample_rate
(
sample_rate_
),
:
sample_rate
(
sample_rate_
),
num_channels
(
num_channels_
),
num_channels
(
num_channels_
),
num_frames
(
num_frames_
){};
num_frames
(
num_frames_
),
bits_per_sample
(
bits_per_sample_
){};
int64_t
SignalInfo
::
getSampleRate
()
const
{
int64_t
SignalInfo
::
getSampleRate
()
const
{
return
sample_rate
;
return
sample_rate
;
...
@@ -30,6 +32,10 @@ int64_t SignalInfo::getNumFrames() const {
...
@@ -30,6 +32,10 @@ int64_t SignalInfo::getNumFrames() const {
return
num_frames
;
return
num_frames
;
}
}
int64_t
SignalInfo
::
getBitsPerSample
()
const
{
return
bits_per_sample
;
}
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
c10
::
optional
<
std
::
string
>&
format
)
{
c10
::
optional
<
std
::
string
>&
format
)
{
...
@@ -46,7 +52,8 @@ c10::intrusive_ptr<SignalInfo> get_info(
...
@@ -46,7 +52,8 @@ c10::intrusive_ptr<SignalInfo> get_info(
return
c10
::
make_intrusive
<
SignalInfo
>
(
return
c10
::
make_intrusive
<
SignalInfo
>
(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
));
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
encoding
.
bits_per_sample
));
}
}
namespace
{
namespace
{
...
...
torchaudio/csrc/sox/io.h
View file @
99ed7183
...
@@ -15,14 +15,17 @@ struct SignalInfo : torch::CustomClassHolder {
...
@@ -15,14 +15,17 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t
sample_rate
;
int64_t
sample_rate
;
int64_t
num_channels
;
int64_t
num_channels
;
int64_t
num_frames
;
int64_t
num_frames
;
int64_t
bits_per_sample
;
SignalInfo
(
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
);
const
int64_t
num_frames_
,
const
int64_t
bits_per_sample_
);
int64_t
getSampleRate
()
const
;
int64_t
getSampleRate
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumFrames
()
const
;
int64_t
getNumFrames
()
const
;
int64_t
getBitsPerSample
()
const
;
};
};
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
...
...
torchaudio/csrc/sox/register.cpp
View file @
99ed7183
...
@@ -42,7 +42,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...
@@ -42,7 +42,8 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m
.
class_
<
torchaudio
::
sox_io
::
SignalInfo
>
(
"SignalInfo"
)
m
.
class_
<
torchaudio
::
sox_io
::
SignalInfo
>
(
"SignalInfo"
)
.
def
(
"get_sample_rate"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getSampleRate
)
.
def
(
"get_sample_rate"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getSampleRate
)
.
def
(
"get_num_channels"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_channels"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
);
.
def
(
"get_num_frames"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getNumFrames
)
.
def
(
"get_bits_per_sample"
,
&
torchaudio
::
sox_io
::
SignalInfo
::
getBitsPerSample
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
);
m
.
def
(
"torchaudio::sox_io_get_info"
,
&
torchaudio
::
sox_io
::
get_info
);
m
.
def
(
m
.
def
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment