Unverified Commit f8eac89b authored by moto's avatar moto Committed by GitHub
Browse files

Add SignalInfo typedef, and extension module (#718)

This is a part of PRs to add new "sox_io" backend. https://github.com/pytorch/audio/pull/726

This PR adds `SignalInfo` structure, which is data exchange interface between Python and C++ in coming TorchScript-based sox IO backend.
For the case, where C++ extension is not available (i.e. Windows), this PR also adds dummy class and module that will be substituted.
This logic is implemented in `torchaudio.extension` moduel.
parent bc1df488
from . import extension
from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
......
#ifndef TORCHAUDIO_REGISTER_H
#define TORCHAUDIO_REGISTER_H
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
namespace {
static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>())
.def("get_sample_rate", &SignalInfo::getSampleRate)
.def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples);
} // namespace
} // namespace torchaudio
#endif
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_samples(num_samples_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}
int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumSamples() const {
return num_samples;
}
} // namespace torchaudio
#ifndef TORCHAUDIO_TYPDEFS_H
#define TORCHAUDIO_TYPDEFS_H
#include <torch/script.h>
namespace torchaudio {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_samples;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_samples_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumSamples() const;
};
} // namespace torchaudio
#endif
from .extension import (
_init_extension,
)
_init_extension()
del _init_extension
import warnings
import importlib
from collections import namedtuple
import torch
from torchaudio._internal import module_utils as _mod_utils
def _init_extension():
ext = 'torchaudio._torchaudio'
if _mod_utils.is_module_available(ext):
_init_script_module(ext)
else:
warnings.warn('torchaudio C++ extension is not available.')
_init_dummy_module()
def _init_script_module(module):
path = importlib.util.find_spec(module).origin
torch.classes.load_library(path)
torch.ops.load_library(path)
def _init_dummy_module():
class SignalInfo:
"""Data class for audio format information
Used when torchaudio C++ extension is not available for annotating
sox_io backend functions so that torchaudio is still importable
without extension.
This class has to implement the same interface as C++ equivalent.
"""
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
self.sample_rate = sample_rate
self.num_channels = num_channels
self.num_samples = num_samples
def get_sample_rate(self):
return self.sample_rate
def get_num_channels(self):
return self.num_channels
def get_num_samples(self):
return self.num_samples
DummyModule = namedtuple('torchaudio', ['SignalInfo'])
module = DummyModule(SignalInfo)
setattr(torch.classes, 'torchaudio', module)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment