Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
180ede8e
"examples/vscode:/vscode.git/clone" did not exist on "bf3994eea68b5841349f1616f41d0f70123a11ec"
Unverified
Commit
180ede8e
authored
Jul 08, 2020
by
moto
Committed by
GitHub
Jul 08, 2020
Browse files
Get rid of typedefs/SignalInfo and replace it with AudioMetaData (#761)
parent
68cc72da
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
75 additions
and
115 deletions
+75
-115
test/sox_io_backend/test_info.py
test/sox_io_backend/test_info.py
+18
-18
test/sox_io_backend/test_torchscript.py
test/sox_io_backend/test_torchscript.py
+4
-4
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+10
-2
torchaudio/csrc/register.cpp
torchaudio/csrc/register.cpp
+6
-11
torchaudio/csrc/sox_effects.h
torchaudio/csrc/sox_effects.h
+0
-1
torchaudio/csrc/sox_io.cpp
torchaudio/csrc/sox_io.cpp
+22
-2
torchaudio/csrc/sox_io.h
torchaudio/csrc/sox_io.h
+15
-2
torchaudio/csrc/typedefs.cpp
torchaudio/csrc/typedefs.cpp
+0
-23
torchaudio/csrc/typedefs.h
torchaudio/csrc/typedefs.h
+0
-23
torchaudio/extension/extension.py
torchaudio/extension/extension.py
+0
-29
No files found.
test/sox_io_backend/test_info.py
View file @
180ede8e
...
@@ -33,9 +33,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -33,9 +33,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
@@ -49,9 +49,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -49,9 +49,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
data
=
get_wav_data
(
dtype
,
num_channels
,
normalize
=
False
,
num_frames
=
duration
*
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -67,10 +67,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -67,10 +67,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression
=
bit_rate
,
duration
=
duration
,
compression
=
bit_rate
,
duration
=
duration
,
)
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
# mp3 does not preserve the number of samples
# mp3 does not preserve the number of samples
# assert info.
get_
num_frames
()
== sample_rate * duration
# assert info.num_frames == sample_rate * duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -86,9 +86,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -86,9 +86,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression
=
compression_level
,
duration
=
duration
,
compression
=
compression_level
,
duration
=
duration
,
)
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
num_channels
==
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
8000
,
16000
],
[
8000
,
16000
],
...
@@ -104,9 +104,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
...
@@ -104,9 +104,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression
=
quality_level
,
duration
=
duration
,
compression
=
quality_level
,
duration
=
duration
,
)
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
sample_rate
assert
info
.
sample_rate
==
sample_rate
assert
info
.
get_
num_frames
()
==
sample_rate
*
duration
assert
info
.
num_frames
==
sample_rate
*
duration
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
num_channels
==
num_channels
@
skipIfNoExtension
@
skipIfNoExtension
...
@@ -120,6 +120,6 @@ class TestInfoOpus(PytorchTestCase):
...
@@ -120,6 +120,6 @@ class TestInfoOpus(PytorchTestCase):
"""`sox_io_backend.info` can check opus file correcty"""
"""`sox_io_backend.info` can check opus file correcty"""
path
=
get_asset_path
(
'io'
,
f
'
{
bitrate
}
_
{
compression_level
}
_
{
num_channels
}
ch.opus'
)
path
=
get_asset_path
(
'io'
,
f
'
{
bitrate
}
_
{
compression_level
}
_
{
num_channels
}
ch.opus'
)
info
=
sox_io_backend
.
info
(
path
)
info
=
sox_io_backend
.
info
(
path
)
assert
info
.
get_
sample_rate
()
==
48000
assert
info
.
sample_rate
==
48000
assert
info
.
get_
num_frames
()
==
32768
assert
info
.
num_frames
==
32768
assert
info
.
get_
num_channels
()
==
num_channels
assert
info
.
num_channels
==
num_channels
test/sox_io_backend/test_torchscript.py
View file @
180ede8e
...
@@ -20,7 +20,7 @@ from .common import (
...
@@ -20,7 +20,7 @@ from .common import (
)
)
def
py_info_func
(
filepath
:
str
)
->
torch
.
classes
.
torchaudio
.
SignalInfo
:
def
py_info_func
(
filepath
:
str
)
->
torch
audio
.
backend
.
sox_io_backend
.
AudioMetaData
:
return
torchaudio
.
info
(
filepath
)
return
torchaudio
.
info
(
filepath
)
...
@@ -63,9 +63,9 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
...
@@ -63,9 +63,9 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_info
=
py_info_func
(
audio_path
)
py_info
=
py_info_func
(
audio_path
)
ts_info
=
ts_info_func
(
audio_path
)
ts_info
=
ts_info_func
(
audio_path
)
assert
py_info
.
get_
sample_rate
()
==
ts_info
.
get_
sample_rate
()
assert
py_info
.
sample_rate
==
ts_info
.
sample_rate
assert
py_info
.
get_
num_frames
()
==
ts_info
.
get_
num_frames
()
assert
py_info
.
num_frames
==
ts_info
.
num_frames
assert
py_info
.
get_
num_channels
()
==
ts_info
.
get_
num_channels
()
assert
py_info
.
num_channels
==
ts_info
.
num_channels
@
parameterized
.
expand
(
list
(
itertools
.
product
(
@
parameterized
.
expand
(
list
(
itertools
.
product
(
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
[
'float32'
,
'int32'
,
'int16'
,
'uint8'
],
...
...
torchaudio/backend/sox_io_backend.py
View file @
180ede8e
...
@@ -6,10 +6,18 @@ from torchaudio._internal import (
...
@@ -6,10 +6,18 @@ from torchaudio._internal import (
)
)
class AudioMetaData:
    """Plain-Python container for the signal information of an audio file.

    The three constructor arguments are stored unchanged as attributes of
    the same name; no validation or conversion is performed.

    Args:
        sample_rate: Value stored as ``self.sample_rate``.
        num_frames: Value stored as ``self.num_frames``.
        num_channels: Value stored as ``self.num_channels``.

    Note:
        The positional order is ``(sample_rate, num_frames, num_channels)``,
        so pass the frame count before the channel count.
    """

    def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        self.num_channels = num_channels
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
def
info
(
filepath
:
str
)
->
torch
.
classes
.
torchaudio
.
SignalInfo
:
def
info
(
filepath
:
str
)
->
AudioMetaData
:
"""Get signal information of an audio file."""
"""Get signal information of an audio file."""
return
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
)
sinfo
=
torch
.
ops
.
torchaudio
.
sox_io_get_info
(
filepath
)
return
AudioMetaData
(
sinfo
.
get_sample_rate
(),
sinfo
.
get_num_frames
(),
sinfo
.
get_num_channels
())
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
@
_mod_utils
.
requires_module
(
'torchaudio._torchaudio'
)
...
...
torchaudio/csrc/register.cpp
View file @
180ede8e
...
@@ -4,21 +4,10 @@
...
@@ -4,21 +4,10 @@
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
namespace
torchaudio
{
namespace
{
namespace
{
////////////////////////////////////////////////////////////////////////////////
// typedefs.h
////////////////////////////////////////////////////////////////////////////////
static
auto
registerSignalInfo
=
torch
::
class_
<
SignalInfo
>
(
"torchaudio"
,
"SignalInfo"
)
.
def
(
torch
::
init
<
int64_t
,
int64_t
,
int64_t
>
())
.
def
(
"get_sample_rate"
,
&
SignalInfo
::
getSampleRate
)
.
def
(
"get_num_channels"
,
&
SignalInfo
::
getNumChannels
)
.
def
(
"get_num_frames"
,
&
SignalInfo
::
getNumFrames
);
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// sox_utils.h
// sox_utils.h
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
...
@@ -32,6 +21,12 @@ static auto registerTensorSignal =
...
@@ -32,6 +21,12 @@ static auto registerTensorSignal =
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
// sox_io.h
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Register sox_io::SignalInfo as the custom class "SignalInfo" under the
// "torchaudio" namespace, exposing its three getters by name.
// NOTE(review): no torch::init<...> is registered here, so instances are
// presumably only created on the C++ side (e.g. by get_info) -- confirm.
static auto registerSignalInfo =
    torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
        .def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
        .def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
        .def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
static
auto
registerGetInfo
=
torch
::
RegisterOperators
().
op
(
static
auto
registerGetInfo
=
torch
::
RegisterOperators
().
op
(
torch
::
RegisterOperators
::
options
()
torch
::
RegisterOperators
::
options
()
.
schema
(
.
schema
(
...
...
torchaudio/csrc/sox_effects.h
View file @
180ede8e
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#define TORCHAUDIO_SOX_EFFECTS_H
#define TORCHAUDIO_SOX_EFFECTS_H
#include <torch/script.h>
#include <torch/script.h>
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
namespace
torchaudio
{
namespace
sox_effects
{
namespace
sox_effects
{
...
...
torchaudio/csrc/sox_io.cpp
View file @
180ede8e
...
@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
...
@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
namespace
torchaudio
{
namespace
torchaudio
{
namespace
sox_io
{
namespace
sox_io
{
c10
::
intrusive_ptr
<
torchaudio
::
SignalInfo
>
get_info
(
const
std
::
string
&
path
)
{
// Construct a SignalInfo by copying the three values into the matching
// members. Argument order is (sample_rate, num_channels, num_frames).
SignalInfo::SignalInfo(
    const int64_t sample_rate_,
    const int64_t num_channels_,
    const int64_t num_frames_)
    : sample_rate(sample_rate_),
      num_channels(num_channels_),
      num_frames(num_frames_) {}

// Return the stored sample rate.
int64_t SignalInfo::getSampleRate() const {
  return sample_rate;
}

// Return the stored number of channels.
int64_t SignalInfo::getNumChannels() const {
  return num_channels;
}

// Return the stored number of frames.
int64_t SignalInfo::getNumFrames() const {
  return num_frames;
}
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
const
std
::
string
&
path
)
{
SoxFormat
sf
(
sox_open_read
(
SoxFormat
sf
(
sox_open_read
(
path
.
c_str
(),
path
.
c_str
(),
/*signal=*/
nullptr
,
/*signal=*/
nullptr
,
...
@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
...
@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
throw
std
::
runtime_error
(
"Error opening audio file"
);
throw
std
::
runtime_error
(
"Error opening audio file"
);
}
}
return
c10
::
make_intrusive
<
torchaudio
::
SignalInfo
>
(
return
c10
::
make_intrusive
<
SignalInfo
>
(
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
rate
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
channels
),
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
));
static_cast
<
int64_t
>
(
sf
->
signal
.
length
/
sf
->
signal
.
channels
));
...
...
torchaudio/csrc/sox_io.h
View file @
180ede8e
...
@@ -3,12 +3,25 @@
...
@@ -3,12 +3,25 @@
#include <torch/script.h>
#include <torch/script.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
namespace
torchaudio
{
namespace
sox_io
{
namespace
sox_io
{
c10
::
intrusive_ptr
<
torchaudio
::
SignalInfo
>
get_info
(
const
std
::
string
&
path
);
// Basic signal properties of an audio file.
// Derives from torch::CustomClassHolder, which lets it be held in a
// c10::intrusive_ptr (as get_info returns) and registered via torch::class_.
struct SignalInfo : torch::CustomClassHolder {
  // Values supplied to the constructor; exposed through the getters below.
  int64_t sample_rate;
  int64_t num_channels;
  int64_t num_frames;

  // Argument order is (sample_rate, num_channels, num_frames).
  SignalInfo(
      const int64_t sample_rate_,
      const int64_t num_channels_,
      const int64_t num_frames_);

  // Read-only accessors for the members above.
  int64_t getSampleRate() const;
  int64_t getNumChannels() const;
  int64_t getNumFrames() const;
};
c10
::
intrusive_ptr
<
SignalInfo
>
get_info
(
const
std
::
string
&
path
);
c10
::
intrusive_ptr
<
torchaudio
::
sox_utils
::
TensorSignal
>
load_audio_file
(
c10
::
intrusive_ptr
<
torchaudio
::
sox_utils
::
TensorSignal
>
load_audio_file
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
...
...
torchaudio/csrc/typedefs.cpp
deleted
100644 → 0
View file @
68cc72da
#include <torchaudio/csrc/typedefs.h>
namespace
torchaudio
{
SignalInfo
::
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
)
:
sample_rate
(
sample_rate_
),
num_channels
(
num_channels_
),
num_frames
(
num_frames_
){};
int64_t
SignalInfo
::
getSampleRate
()
const
{
return
sample_rate
;
}
int64_t
SignalInfo
::
getNumChannels
()
const
{
return
num_channels
;
}
int64_t
SignalInfo
::
getNumFrames
()
const
{
return
num_frames
;
}
}
// namespace torchaudio
torchaudio/csrc/typedefs.h
deleted
100644 → 0
View file @
68cc72da
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H
#include <torch/script.h>
namespace
torchaudio
{
struct
SignalInfo
:
torch
::
CustomClassHolder
{
int64_t
sample_rate
;
int64_t
num_channels
;
int64_t
num_frames
;
SignalInfo
(
const
int64_t
sample_rate_
,
const
int64_t
num_channels_
,
const
int64_t
num_frames_
);
int64_t
getSampleRate
()
const
;
int64_t
getNumChannels
()
const
;
int64_t
getNumFrames
()
const
;
};
}
// namespace torchaudio
#endif
torchaudio/extension/extension.py
View file @
180ede8e
...
@@ -12,38 +12,9 @@ def _init_extension():
...
@@ -12,38 +12,9 @@ def _init_extension():
_init_script_module
(
ext
)
_init_script_module
(
ext
)
else
:
else
:
warnings
.
warn
(
'torchaudio C++ extension is not available.'
)
warnings
.
warn
(
'torchaudio C++ extension is not available.'
)
_init_dummy_module
()
def
_init_script_module
(
module
):
def
_init_script_module
(
module
):
path
=
importlib
.
util
.
find_spec
(
module
).
origin
path
=
importlib
.
util
.
find_spec
(
module
).
origin
torch
.
classes
.
load_library
(
path
)
torch
.
classes
.
load_library
(
path
)
torch
.
ops
.
load_library
(
path
)
torch
.
ops
.
load_library
(
path
)
def
_init_dummy_module
():
class
SignalInfo
:
"""Data class for audio format information
Used when torchaudio C++ extension is not available for annotating
sox_io backend functions so that torchaudio is still importable
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def
__init__
(
self
,
sample_rate
:
int
,
num_channels
:
int
,
num_frames
:
int
):
self
.
sample_rate
=
sample_rate
self
.
num_channels
=
num_channels
self
.
num_frames
=
num_frames
def
get_sample_rate
(
self
):
return
self
.
sample_rate
def
get_num_channels
(
self
):
return
self
.
num_channels
def
get_num_frames
(
self
):
return
self
.
num_frames
DummyModule
=
namedtuple
(
'torchaudio'
,
[
'SignalInfo'
])
module
=
DummyModule
(
SignalInfo
)
setattr
(
torch
.
classes
,
'torchaudio'
,
module
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment