Unverified Commit 180ede8e authored by moto's avatar moto Committed by GitHub
Browse files

Get rid of typedefs/SignalInfo and replace AudioMetaData (#761)

parent 68cc72da
...@@ -33,9 +33,9 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -33,9 +33,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.sample_rate == sample_rate
assert info.get_num_frames() == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
...@@ -49,9 +49,9 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -49,9 +49,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate) save_wav(path, data, sample_rate)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.sample_rate == sample_rate
assert info.get_num_frames() == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[8000, 16000], [8000, 16000],
...@@ -67,10 +67,10 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -67,10 +67,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression=bit_rate, duration=duration, compression=bit_rate, duration=duration,
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.sample_rate == sample_rate
# mp3 does not preserve the number of samples # mp3 does not preserve the number of samples
# assert info.get_num_frames() == sample_rate * duration # assert info.num_frames == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[8000, 16000], [8000, 16000],
...@@ -86,9 +86,9 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -86,9 +86,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression=compression_level, duration=duration, compression=compression_level, duration=duration,
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.sample_rate == sample_rate
assert info.get_num_frames() == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[8000, 16000], [8000, 16000],
...@@ -104,9 +104,9 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -104,9 +104,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression=quality_level, duration=duration, compression=quality_level, duration=duration,
) )
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.sample_rate == sample_rate
assert info.get_num_frames() == sample_rate * duration assert info.num_frames == sample_rate * duration
assert info.get_num_channels() == num_channels assert info.num_channels == num_channels
@skipIfNoExtension @skipIfNoExtension
...@@ -120,6 +120,6 @@ class TestInfoOpus(PytorchTestCase): ...@@ -120,6 +120,6 @@ class TestInfoOpus(PytorchTestCase):
"""`sox_io_backend.info` can check opus file correcty""" """`sox_io_backend.info` can check opus file correcty"""
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus') path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == 48000 assert info.sample_rate == 48000
assert info.get_num_frames() == 32768 assert info.num_frames == 32768
assert info.get_num_channels() == num_channels assert info.num_channels == num_channels
...@@ -20,7 +20,7 @@ from .common import ( ...@@ -20,7 +20,7 @@ from .common import (
) )
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
return torchaudio.info(filepath) return torchaudio.info(filepath)
...@@ -63,9 +63,9 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -63,9 +63,9 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_info = py_info_func(audio_path) py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path) ts_info = ts_info_func(audio_path)
assert py_info.get_sample_rate() == ts_info.get_sample_rate() assert py_info.sample_rate == ts_info.sample_rate
assert py_info.get_num_frames() == ts_info.get_num_frames() assert py_info.num_frames == ts_info.num_frames
assert py_info.get_num_channels() == ts_info.get_num_channels() assert py_info.num_channels == ts_info.num_channels
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
......
...@@ -6,10 +6,18 @@ from torchaudio._internal import ( ...@@ -6,10 +6,18 @@ from torchaudio._internal import (
) )
class AudioMetaData:
    """Plain Python container for the metadata of an audio file.

    Attributes are stored exactly as passed to the constructor and are read
    directly (``meta.sample_rate``), not through getter methods.

    Args:
        sample_rate (int): Sample rate of the audio.
        num_frames (int): Number of frames.
        num_channels (int): Number of channels.
    """
    def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
        # NOTE: the argument order here is (sample_rate, num_frames,
        # num_channels); keep call sites in sync when constructing instances.
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        self.num_channels = num_channels
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo: def info(filepath: str) -> AudioMetaData:
"""Get signal information of an audio file.""" """Get signal information of an audio file."""
return torch.ops.torchaudio.sox_io_get_info(filepath) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
......
...@@ -4,21 +4,10 @@ ...@@ -4,21 +4,10 @@
#include <torchaudio/csrc/sox_effects.h> #include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_io.h> #include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h> #include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio { namespace torchaudio {
namespace { namespace {
////////////////////////////////////////////////////////////////////////////////
// typedefs.h
////////////////////////////////////////////////////////////////////////////////
// Registers SignalInfo as a custom class named "SignalInfo" under the
// "torchaudio" namespace. Exposes a constructor taking three int64_t values
// and the three getter methods below under snake_case names.
static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_frames", &SignalInfo::getNumFrames);
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// sox_utils.h // sox_utils.h
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
...@@ -32,6 +21,12 @@ static auto registerTensorSignal = ...@@ -32,6 +21,12 @@ static auto registerTensorSignal =
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// sox_io.h // sox_io.h
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Registers sox_io::SignalInfo as a custom class named "SignalInfo" under the
// "torchaudio" namespace, exposing only the three getter methods.
// NOTE(review): no torch::init is registered in this chain, so instances are
// presumably only created on the C++ side — confirm.
static auto registerSignalInfo =
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
static auto registerGetInfo = torch::RegisterOperators().op( static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options() torch::RegisterOperators::options()
.schema( .schema(
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#define TORCHAUDIO_SOX_EFFECTS_H #define TORCHAUDIO_SOX_EFFECTS_H
#include <torch/script.h> #include <torch/script.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio { namespace torchaudio {
namespace sox_effects { namespace sox_effects {
......
...@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils; ...@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
namespace torchaudio { namespace torchaudio {
namespace sox_io { namespace sox_io {
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) { SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}
int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
SoxFormat sf(sox_open_read( SoxFormat sf(sox_open_read(
path.c_str(), path.c_str(),
/*signal=*/nullptr, /*signal=*/nullptr,
...@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) { ...@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
throw std::runtime_error("Error opening audio file"); throw std::runtime_error("Error opening audio file");
} }
return c10::make_intrusive<torchaudio::SignalInfo>( return c10::make_intrusive<SignalInfo>(
static_cast<int64_t>(sf->signal.rate), static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels), static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels)); static_cast<int64_t>(sf->signal.length / sf->signal.channels));
......
...@@ -3,12 +3,25 @@ ...@@ -3,12 +3,25 @@
#include <torch/script.h> #include <torch/script.h>
#include <torchaudio/csrc/sox_utils.h> #include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio { namespace torchaudio {
namespace sox_io { namespace sox_io {
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path); struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
};
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file( c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path, const std::string& path,
......
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
// Construct a SignalInfo by copying the given values into the members.
// Argument order is: sample rate, number of channels, number of frames.
SignalInfo::SignalInfo(
    const int64_t sample_rate_,
    const int64_t num_channels_,
    const int64_t num_frames_)
    : sample_rate(sample_rate_),
      num_channels(num_channels_),
      num_frames(num_frames_) {}  // no trailing ';' — it would be an empty declaration

// Return the stored sample rate.
int64_t SignalInfo::getSampleRate() const {
  return sample_rate;
}

// Return the stored number of channels.
int64_t SignalInfo::getNumChannels() const {
  return num_channels;
}

// Return the stored number of frames.
int64_t SignalInfo::getNumFrames() const {
  return num_frames;
}
} // namespace torchaudio
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H
#include <torch/script.h>
namespace torchaudio {
// Holder for three integer properties of an audio signal. Derives from
// torch::CustomClassHolder; the members are public and are also readable
// through the const getters declared below.
struct SignalInfo : torch::CustomClassHolder {
  int64_t sample_rate;
  int64_t num_channels;
  int64_t num_frames;

  // Argument order: sample rate, number of channels, number of frames.
  // Top-level const on by-value parameters is not part of the function type,
  // so it is omitted in this declaration.
  SignalInfo(int64_t sample_rate_, int64_t num_channels_, int64_t num_frames_);

  // Read-only accessors for the members above.
  int64_t getSampleRate() const;
  int64_t getNumChannels() const;
  int64_t getNumFrames() const;
};
} // namespace torchaudio
#endif
...@@ -12,38 +12,9 @@ def _init_extension(): ...@@ -12,38 +12,9 @@ def _init_extension():
_init_script_module(ext) _init_script_module(ext)
else: else:
warnings.warn('torchaudio C++ extension is not available.') warnings.warn('torchaudio C++ extension is not available.')
_init_dummy_module()
def _init_script_module(module): def _init_script_module(module):
path = importlib.util.find_spec(module).origin path = importlib.util.find_spec(module).origin
torch.classes.load_library(path) torch.classes.load_library(path)
torch.ops.load_library(path) torch.ops.load_library(path)
def _init_dummy_module():
    """Install a pure-Python stand-in as ``torch.classes.torchaudio``.

    Defines a Python ``SignalInfo`` class, wraps it in a one-field namedtuple
    type named ``torchaudio``, and sets that object as the ``torchaudio``
    attribute of ``torch.classes``, so that
    ``torch.classes.torchaudio.SignalInfo`` resolves to the Python class.
    """
    class SignalInfo:
        """Data class for audio format information.

        Used when torchaudio C++ extension is not available for annotating
        sox_io backend functions so that torchaudio is still importable
        without extension.

        This class has to implement the same interface as C++ equivalent.
        """
        def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
            self.sample_rate = sample_rate
            self.num_channels = num_channels
            self.num_frames = num_frames

        def get_sample_rate(self) -> int:
            """Return the stored sample rate."""
            return self.sample_rate

        def get_num_channels(self) -> int:
            """Return the stored number of channels."""
            return self.num_channels

        def get_num_frames(self) -> int:
            """Return the stored number of frames."""
            return self.num_frames

    # A namedtuple instance is used as a minimal attribute container so that
    # ``<module>.SignalInfo`` attribute access works.
    DummyModule = namedtuple('torchaudio', ['SignalInfo'])
    module = DummyModule(SignalInfo)
    setattr(torch.classes, 'torchaudio', module)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment