Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
180ede8e
Unverified
Commit
180ede8e
authored
Jul 08, 2020
by
moto
Committed by
GitHub
Jul 08, 2020
Browse files
Get rid of typedefs/SignalInfo and replace AudioMetaData (#761)
parent
68cc72da
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
75 additions
and
115 deletions
+75
-115
test/sox_io_backend/test_info.py
test/sox_io_backend/test_info.py
+18
-18
test/sox_io_backend/test_torchscript.py
test/sox_io_backend/test_torchscript.py
+4
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+10
-2
torchaudio/csrc/register.cpp
torchaudio/csrc/register.cpp
+6
-11
torchaudio/csrc/sox_effects.h
torchaudio/csrc/sox_effects.h
+0
-1
torchaudio/csrc/sox_io.cpp
torchaudio/csrc/sox_io.cpp
+22
-2
torchaudio/csrc/sox_io.h
torchaudio/csrc/sox_io.h
+15
-2
torchaudio/csrc/typedefs.cpp
torchaudio/csrc/typedefs.cpp
+0
-23
torchaudio/csrc/typedefs.h
torchaudio/csrc/typedefs.h
+0
-23
torchaudio/extension/extension.py
torchaudio/extension/extension.py
+0
-29
No files found.
test/sox_io_backend/test_info.py
View file @
180ede8e
...
...
@@ -33,9 +33,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
...
@@ -49,9 +49,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -67,10 +67,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression
=
bit_rate
,
duration
=
duration
,
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
# mp3 does not preserve the number of samples
# assert info.
get_
num_frames
()
== sample_rate * duration
assert
info
.
get_
num_channels
()
==
num_channels
# assert info.num_frames == sample_rate * duration
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -86,9 +86,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression
=
compression_level
,
duration
=
duration
,
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
...
...
@@ -104,9 +104,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression
=
quality_level
,
duration
=
duration
,
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
sample_rate
==
sample_rate
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
num_channels
==
num_channels
@
skipIfNoExtension
...
...
@@ -120,6 +120,6 @@ class TestInfoOpus(PytorchTestCase):
"""`sox_io_backend.info` can check opus file correcty"""
path
=
get_asset_path
(
'io'
,
f
'
{
bitrate
}
_
{
compression_level
}
_
{
num_channels
}
ch.opus'
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
48000
assert
info
.
get_
num_frames
()
==
32768
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
sample_rate
==
48000
assert
info
.
num_frames
==
32768
assert
info
.
num_channels
==
num_channels
test/sox_io_backend/test_torchscript.py
View file @
180ede8e
...
...
@@ -20,7 +20,7 @@ from .common import (
)
def
py_info_func
(
filepath
:
str
)
->
torch
.
classes
.
torchaudio
.
SignalInfo
:
def
py_info_func
(
filepath
:
str
)
->
torch
audio
.
backend
.
sox_io_backend
.
AudioMetaData
:
return
torchaudio
.
info
(
filepath
)
...
...
@@ -63,9 +63,9 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_info
=
py_info_func
(
audio_path
)
ts_info
=
ts_info_func
(
audio_path
)
assert
py_info
.
get_
sample_rate
()
==
ts_info
.
get_
sample_rate
()
assert
py_info
.
get_
num_frames
()
==
ts_info
.
get_
num_frames
()
assert
py_info
.
get_
num_channels
()
==
ts_info
.
get_
num_channels
()
assert
py_info
.
sample_rate
==
ts_info
.
sample_rate
assert
py_info
.
num_frames
==
ts_info
.
num_frames
assert
py_info
.
num_channels
==
ts_info
.
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
...
torchaudio/backend/sox_io_backend.py
View file @
180ede8e
...
...
@@ -6,10 +6,18 @@ from torchaudio._internal import (
)
class
AudioMetaData
:
def
__init__
(
self
,
sample_rate
:
int
,
num_frames
:
int
,
num_channels
:
int
):
self
.
sample_rate
=
sample_rate
self
.
num_frames
=
num_frames
self
.
num_channels
=
num_channels
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
def
info
(
filepath
:
str
)
->
torch
.
classes
.
torchaudio
.
SignalInfo
:
def
info
(
filepath
:
str
)
->
AudioMetaData
:
"""Get signal information of an audio file."""
return
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
())
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/register.cpp
View file @
180ede8e
...
...
@@ -4,21 +4,10 @@
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
namespace
{
////////////////////////////////////////////////////////////////////////////////
// typedefs.h
////////////////////////////////////////////////////////////////////////////////
static
auto
registerSignalInfo
=
torch
::
class_
<
SignalInfo
>
(
"torchaudio"
,
"SignalInfo"
)
.
def
(
torch
::
init
<
int64_t
,
int64_t
,
int64_t
>
())
.
def
(
"get_sample_rate"
,
&
SignalInfo
::
getSampleRate
)
.
def
(
"get_num_channels"
,
&
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_frames"
,
&
SignalInfo
::
getNumFrames
);
////////////////////////////////////////////////////////////////////////////////
// sox_utils.h
////////////////////////////////////////////////////////////////////////////////
...
...
@@ -32,6 +21,12 @@ static auto registerTensorSignal =
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
////////////////////////////////////////////////////////////////////////////////
static
auto
registerSignalInfo
=
torch
::
class_
<
sox_io
::
SignalInfo
>
(
"torchaudio"
,
"SignalInfo"
)
.
def
(
"get_sample_rate"
,
&
sox_io
::
SignalInfo
::
getSampleRate
)
.
def
(
"get_num_channels"
,
&
sox_io
::
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_frames"
,
&
sox_io
::
SignalInfo
::
getNumFrames
);
static
auto
registerGetInfo
=
torch
::
RegisterOperators
().
op
(
torch
::
RegisterOperators
::
options
()
.
schema
(
...
...
torchaudio/csrc/sox_effects.h
View file @
180ede8e
...
...
@@ -2,7 +2,6 @@
#define TORCHAUDIO_SOX_EFFECTS_H
#include <torch/script.h>
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
namespace
sox_effects
{
...
...
torchaudio/csrc/sox_io.cpp
View file @
180ede8e
...
...
@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
namespace
torchaudio
{
namespace
sox_io
{
c10
::
intrusive_ptr
<
torchaudio
::
SignalInfo
>
get_info
(
const
std
::
string
&
path
)
{
SignalInfo
::
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
)
:
sample_rate
(
sample_rate_
),
num_channels
(
num_channels_
),
num_frames
(
num_frames_
){};
int64_t
SignalInfo
::
getSampleRate
()
const
{
return
sample_rate
;
}
int64_t
SignalInfo
::
getNumChannels
()
const
{
return
num_channels
;
}
int64_t
SignalInfo
::
getNumFrames
()
const
{
return
num_frames
;
}
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
const
std
::
string
&
path
)
{
SoxFormat
sf
(
sox_open_read
(
path
.
c_str
(),
/*signal=*/
nullptr
,
...
...
@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
throw
std
::
runtime_error
(
"Error opening audio file"
);
}
return
c10
::
make_intrusive
<
torchaudio
::
SignalInfo
>
(
return
c10
::
make_intrusive
<
SignalInfo
>
(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
));
...
...
torchaudio/csrc/sox_io.h
View file @
180ede8e
...
...
@@ -3,12 +3,25 @@
#include <torch/script.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
namespace
sox_io
{
c10
::
intrusive_ptr
<
torchaudio
::
SignalInfo
>
get_info
(
const
std
::
string
&
path
);
struct
SignalInfo
:
torch
::
CustomClassHolder
{
int64_t
sample_rate
;
int64_t
num_channels
;
int64_t
num_frames
;
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
);
int64_t
getSampleRate
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumFrames
()
const
;
};
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
const
std
::
string
&
path
);
c10
::
intrusive_ptr
<
torchaudio
::
sox_utils
::
TensorSignal
>
load_audio_file
(
const
std
::
string
&
path
,
...
...
torchaudio/csrc/typedefs.cpp
deleted
100644 → 0
View file @
68cc72da
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
SignalInfo
::
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
)
:
sample_rate
(
sample_rate_
),
num_channels
(
num_channels_
),
num_frames
(
num_frames_
){};
int64_t
SignalInfo
::
getSampleRate
()
const
{
return
sample_rate
;
}
int64_t
SignalInfo
::
getNumChannels
()
const
{
return
num_channels
;
}
int64_t
SignalInfo
::
getNumFrames
()
const
{
return
num_frames
;
}
}
// namespace torchaudio
torchaudio/csrc/typedefs.h
deleted
100644 → 0
View file @
68cc72da
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H
#include <torch/script.h>
namespace
torchaudio
{
struct
SignalInfo
:
torch
::
CustomClassHolder
{
int64_t
sample_rate
;
int64_t
num_channels
;
int64_t
num_frames
;
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
);
int64_t
getSampleRate
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumFrames
()
const
;
};
}
// namespace torchaudio
#endif
torchaudio/extension/extension.py
View file @
180ede8e
...
...
@@ -12,38 +12,9 @@ def _init_extension():
_init_script_module
(
ext
)
else
:
warnings
.
warn
(
'torchaudio C++ extension is not available.'
)
_init_dummy_module
()
def
_init_script_module
(
module
):
path
=
importlib
.
util
.
find_spec
(
module
).
origin
torch
.
classes
.
load_library
(
path
)
torch
.
ops
.
load_library
(
path
)
def
_init_dummy_module
():
class
SignalInfo
:
"""Data class for audio format information
Used when torchaudio C++ extension is not available for annotating
sox_io backend functions so that torchaudio is still importable
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def
__init__
(
self
,
sample_rate
:
int
,
num_channels
:
int
,
num_frames
:
int
):
self
.
sample_rate
=
sample_rate
self
.
num_channels
=
num_channels
self
.
num_frames
=
num_frames
def
get_sample_rate
(
self
):
return
self
.
sample_rate
def
get_num_channels
(
self
):
return
self
.
num_channels
def
get_num_frames
(
self
):
return
self
.
num_frames
DummyModule
=
namedtuple
(
'torchaudio'
,
[
'SignalInfo'
])
module
=
DummyModule
(
SignalInfo
)
setattr
(
torch
.
classes
,
'torchaudio'
,
module
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment