Unverified Commit 8b93bd68 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Simplified C++ interface for sox_io's get_info_file() (#1232)


Co-authored-by: default avatarPrabhat Roy <prabhatroy@fb.com>
parent 3651412b
......@@ -11,24 +11,6 @@ import torchaudio
from .common import AudioMetaData
@torch.jit.unused
def _info(
filepath: str,
format: Optional[str] = None,
) -> AudioMetaData:
if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
sinfo = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample(),
sinfo.get_encoding(),
)
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(
filepath: str,
......@@ -61,14 +43,12 @@ def info(
AudioMetaData: Metadata of the given audio.
"""
if not torch.jit.is_scripting():
return _info(filepath, format)
if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample(),
sinfo.get_encoding())
return AudioMetaData(*sinfo)
@_mod_utils.requires_module('torchaudio._torchaudio')
......
......@@ -10,38 +10,6 @@ using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_io {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_,
const int64_t bits_per_sample_,
const std::string encoding_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_),
bits_per_sample(bits_per_sample_),
encoding(encoding_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}
int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample;
}
std::string SignalInfo::getEncoding() const {
return encoding;
}
namespace {
std::string get_encoding(sox_encoding_t encoding) {
......@@ -77,7 +45,7 @@ std::string get_encoding(sox_encoding_t encoding) {
} // namespace
c10::intrusive_ptr<SignalInfo> get_info_file(
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
const std::string& path,
c10::optional<std::string>& format) {
SoxFormat sf(sox_open_read(
......@@ -90,10 +58,10 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
throw std::runtime_error("Error opening audio file");
}
return c10::make_intrusive<SignalInfo>(
return std::make_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample),
get_encoding(sf->encoding.encoding));
}
......@@ -347,15 +315,6 @@ void save_audio_fileobj(
#endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample)
.def("get_encoding", &torchaudio::sox_io::SignalInfo::getEncoding);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
"torchaudio::sox_io_load_audio_file",
......
......@@ -11,27 +11,7 @@
namespace torchaudio {
namespace sox_io {
struct SignalInfo : torch::CustomClassHolder {
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
int64_t bits_per_sample;
std::string encoding;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_,
const int64_t bits_per_sample_,
const std::string encoding_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
int64_t getBitsPerSample() const;
std::string getEncoding() const;
};
c10::intrusive_ptr<SignalInfo> get_info_file(
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
const std::string& path,
c10::optional<std::string>& format);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment