Unverified Commit 8b93bd68 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Simplified C++ interface for sox_io's get_info_file() (#1232)


Co-authored-by: default avatarPrabhat Roy <prabhatroy@fb.com>
parent 3651412b
...@@ -11,24 +11,6 @@ import torchaudio ...@@ -11,24 +11,6 @@ import torchaudio
from .common import AudioMetaData from .common import AudioMetaData
@torch.jit.unused
def _info(
filepath: str,
format: Optional[str] = None,
) -> AudioMetaData:
if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
sinfo = torch.ops.torchaudio.sox_io_get_info(os.fspath(filepath), format)
return AudioMetaData(
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample(),
sinfo.get_encoding(),
)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
def info( def info(
filepath: str, filepath: str,
...@@ -61,14 +43,12 @@ def info( ...@@ -61,14 +43,12 @@ def info(
AudioMetaData: Metadata of the given audio. AudioMetaData: Metadata of the given audio.
""" """
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
return _info(filepath, format) if hasattr(filepath, 'read'):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData( return AudioMetaData(*sinfo)
sinfo.get_sample_rate(),
sinfo.get_num_frames(),
sinfo.get_num_channels(),
sinfo.get_bits_per_sample(),
sinfo.get_encoding())
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_module('torchaudio._torchaudio')
......
...@@ -10,38 +10,6 @@ using namespace torchaudio::sox_utils; ...@@ -10,38 +10,6 @@ using namespace torchaudio::sox_utils;
namespace torchaudio { namespace torchaudio {
namespace sox_io { namespace sox_io {
SignalInfo::SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_,
const int64_t bits_per_sample_,
const std::string encoding_)
: sample_rate(sample_rate_),
num_channels(num_channels_),
num_frames(num_frames_),
bits_per_sample(bits_per_sample_),
encoding(encoding_){};
int64_t SignalInfo::getSampleRate() const {
return sample_rate;
}
int64_t SignalInfo::getNumChannels() const {
return num_channels;
}
int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
int64_t SignalInfo::getBitsPerSample() const {
return bits_per_sample;
}
std::string SignalInfo::getEncoding() const {
return encoding;
}
namespace { namespace {
std::string get_encoding(sox_encoding_t encoding) { std::string get_encoding(sox_encoding_t encoding) {
...@@ -77,7 +45,7 @@ std::string get_encoding(sox_encoding_t encoding) { ...@@ -77,7 +45,7 @@ std::string get_encoding(sox_encoding_t encoding) {
} // namespace } // namespace
c10::intrusive_ptr<SignalInfo> get_info_file( std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
const std::string& path, const std::string& path,
c10::optional<std::string>& format) { c10::optional<std::string>& format) {
SoxFormat sf(sox_open_read( SoxFormat sf(sox_open_read(
...@@ -90,10 +58,10 @@ c10::intrusive_ptr<SignalInfo> get_info_file( ...@@ -90,10 +58,10 @@ c10::intrusive_ptr<SignalInfo> get_info_file(
throw std::runtime_error("Error opening audio file"); throw std::runtime_error("Error opening audio file");
} }
return c10::make_intrusive<SignalInfo>( return std::make_tuple(
static_cast<int64_t>(sf->signal.rate), static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels), static_cast<int64_t>(sf->signal.length / sf->signal.channels),
static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->encoding.bits_per_sample), static_cast<int64_t>(sf->encoding.bits_per_sample),
get_encoding(sf->encoding.encoding)); get_encoding(sf->encoding.encoding));
} }
...@@ -347,15 +315,6 @@ void save_audio_fileobj( ...@@ -347,15 +315,6 @@ void save_audio_fileobj(
#endif // TORCH_API_INCLUDE_EXTENSION_H #endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<torchaudio::sox_io::SignalInfo>("SignalInfo")
.def("get_sample_rate", &torchaudio::sox_io::SignalInfo::getSampleRate)
.def("get_num_channels", &torchaudio::sox_io::SignalInfo::getNumChannels)
.def("get_num_frames", &torchaudio::sox_io::SignalInfo::getNumFrames)
.def(
"get_bits_per_sample",
&torchaudio::sox_io::SignalInfo::getBitsPerSample)
.def("get_encoding", &torchaudio::sox_io::SignalInfo::getEncoding);
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file); m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def( m.def(
"torchaudio::sox_io_load_audio_file", "torchaudio::sox_io_load_audio_file",
......
...@@ -11,27 +11,7 @@ ...@@ -11,27 +11,7 @@
namespace torchaudio { namespace torchaudio {
namespace sox_io { namespace sox_io {
struct SignalInfo : torch::CustomClassHolder { std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
int64_t sample_rate;
int64_t num_channels;
int64_t num_frames;
int64_t bits_per_sample;
std::string encoding;
SignalInfo(
const int64_t sample_rate_,
const int64_t num_channels_,
const int64_t num_frames_,
const int64_t bits_per_sample_,
const std::string encoding_);
int64_t getSampleRate() const;
int64_t getNumChannels() const;
int64_t getNumFrames() const;
int64_t getBitsPerSample() const;
std::string getEncoding() const;
};
c10::intrusive_ptr<SignalInfo> get_info_file(
const std::string& path, const std::string& path,
c10::optional<std::string>& format); c10::optional<std::string>& format);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment