Unverified Commit 180ede8e authored by moto's avatar moto Committed by GitHub
Browse files

Get rid of typedefs/SignalInfo and replace AudioMetaData (#761)

parent 68cc72da
......@@ -33,9 +33,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
......@@ -49,9 +49,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -67,10 +67,10 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression=bit_rate, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.sample_rate == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
# assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -86,9 +86,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression=compression_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
@parameterized.expand(list(itertools.product(
[8000, 16000],
......@@ -104,9 +104,9 @@ class TestInfo(TempDirMixin, PytorchTestCase):
compression=quality_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration
assert info.get_num_channels() == num_channels
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
@skipIfNoExtension
......@@ -120,6 +120,6 @@ class TestInfoOpus(PytorchTestCase):
"""`sox_io_backend.info` can check opus file correcty"""
path = get_asset_path('io', f'{bitrate}_{compression_level}_{num_channels}ch.opus')
info = sox_io_backend.info(path)
assert info.get_sample_rate() == 48000
assert info.get_num_frames() == 32768
assert info.get_num_channels() == num_channels
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
......@@ -20,7 +20,7 @@ from .common import (
)
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
def py_info_func(filepath: str) -> torchaudio.backend.sox_io_backend.AudioMetaData:
return torchaudio.info(filepath)
......@@ -63,9 +63,9 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)
assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels()
assert py_info.sample_rate == ts_info.sample_rate
assert py_info.num_frames == ts_info.num_frames
assert py_info.num_channels == ts_info.num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
......
......@@ -6,10 +6,18 @@ from torchaudio._internal import (
)
class AudioMetaData:
def __init__(self, sample_rate: int, num_frames: int, num_channels: int):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
def info(filepath: str) -> AudioMetaData:
"""Get signal information of an audio file."""
return torch.ops.torchaudio.sox_io_get_info(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
@_mod_utils.requires_module('torchaudio._torchaudio')
......
......@@ -4,21 +4,10 @@
#include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
namespace {
////////////////////////////////////////////////////////////////////////////////
// typedefs.h
////////////////////////////////////////////////////////////////////////////////
static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_frames", &SignalInfo::getNumFrames);
////////////////////////////////////////////////////////////////////////////////
// sox_utils.h
////////////////////////////////////////////////////////////////////////////////
......@@ -32,6 +21,12 @@ static auto registerTensorSignal =
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
////////////////////////////////////////////////////////////////////////////////
static auto registerSignalInfo =
torch::class_<sox_io::SignalInfo>("torchaudio", "SignalInfo")
.def("get_sample_rate", &sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &sox_io::SignalInfo::getNumFrames);
static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
......
......@@ -2,7 +2,6 @@
#define TORCHAUDIO_SOX_EFFECTS_H
#include <torch/script.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
namespace sox_effects {
......
......@@ -8,7 +8,27 @@ using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_io {
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}
int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
......@@ -19,7 +39,7 @@ c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
throw std::runtime_error("Error opening audio file");
}
return c10::make_intrusive<torchaudio::SignalInfo>(
return c10::make_intrusive<SignalInfo>(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
......
......@@ -3,12 +3,25 @@
#include <torch/script.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
namespace sox_io {
c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path);
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
};
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path,
......
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}
int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
} // namespace torchaudio
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H
#include <torch/script.h>
namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
};
} // namespace torchaudio
#endif
......@@ -12,38 +12,9 @@ def _init_extension():
_init_script_module(ext)
else:
warnings.warn('torchaudio C++ extension is not available.')
_init_dummy_module()
def _init_script_module(module):
path = importlib.util.find_spec(module).origin
torch.classes.load_library(path)
torch.ops.load_library(path)
def _init_dummy_module():
class SignalInfo:
"""Data class for audio format information
Used when torchaudio C++ extension is not available for annotating
sox_io backend functions so that torchaudio is still importable
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_frames: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_frames = num_frames
def get_sample_rate(self):
return self.sample_rate
def get_num_channels(self):
return self.num_channels
def get_num_frames(self):
return self.num_frames
DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
setattr(torch.classes, 'torchaudio', module)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment