Commit 57f7f522 (unverified), authored by moto-meta, committed by GitHub
Browse files

Remove FFmpeg compat load/info function

Differential Revision: D50229857

Pull Request resolved: https://github.com/pytorch/audio/pull/3652
parent e65e4726
......@@ -15,7 +15,6 @@ set(
stream_writer/packet_writer.cpp
stream_writer/stream_writer.cpp
stream_writer/tensor_converter.cpp
compat.cpp
)
set(
......
#include <libtorio/ffmpeg/stream_reader/stream_reader.h>
#include <torch/script.h>
#include <stdexcept>
namespace torchaudio {
namespace io {
namespace {
// Decodes audio stream `i` of `s` in one pass and returns the decoded frames.
//
// An output audio stream is attached to `s` (with `filter` forwarded as the
// filter description), all packets are processed, and the first popped chunk
// is used as the result. Raises through TORCH_CHECK when that chunk is empty.
// When `channels_first` is true, dimensions 0 and 1 of the decoded tensor are
// swapped before it is returned; otherwise the tensor is returned as decoded.
torch::Tensor _load_audio(
    StreamReader& s,
    int i,
    const c10::optional<std::string>& filter,
    const bool& channels_first) {
  s.add_audio_stream(i, -1, -1, filter, {}, {});
  s.process_all_packets();

  // Only one output stream was configured above, so only slot 0 is relevant.
  auto first_chunk = s.pop_chunks()[0];
  TORCH_CHECK(first_chunk, "Failed to decode audio.");

  auto frames = first_chunk.value().frames;
  if (channels_first) {
    return frames.transpose(0, 1);
  }
  return frames;
}
// Loads the best audio stream of `src` and returns (waveform, sample_rate).
//
// `format` is forwarded to the StreamReader constructor, and `filter` and
// `channels_first` are forwarded to `_load_audio`. Raises through TORCH_CHECK
// when `src` has no audio stream or when decoding yields no chunk.
std::tuple<torch::Tensor, int64_t> load(
    const std::string& src,
    const c10::optional<std::string>& format,
    const c10::optional<std::string>& filter,
    const bool& channels_first) {
  StreamReader s{src, format, {}};
  const auto i = s.find_best_audio_stream();
  // A negative index means no audio stream was found; fail here with a clear
  // message instead of passing an invalid index further down.
  TORCH_CHECK(i >= 0, "No audio stream found in: ", src);
  // Convert explicitly so the narrowing to the int64_t tuple slot is visible.
  const auto sample_rate =
      static_cast<int64_t>(s.get_src_stream_info(i).sample_rate);
  auto waveform = _load_audio(s, i, filter, channels_first);
  return {waveform, sample_rate};
}
// Returns metadata of the best audio stream of `src` as the tuple
// (sample_rate, num_frames, num_channels, bits_per_sample, codec_name).
//
// `format` is forwarded to the StreamReader constructor. Raises through
// TORCH_CHECK when `src` has no audio stream, or when the frame count has to
// be obtained by decoding and decoding yields no chunk.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> info(
    const std::string& src,
    const c10::optional<std::string>& format) {
  StreamReader s{src, format, {}};
  const auto i = s.find_best_audio_stream();
  // A negative index means no audio stream was found; fail here with a clear
  // message instead of passing an invalid index further down.
  TORCH_CHECK(i >= 0, "No audio stream found in: ", src);
  const auto sinfo = s.get_src_stream_info(i);

  // When the stream info reports zero frames, decode the whole stream and take
  // dimension 0 of the untransposed result as the frame count.
  const int64_t num_frames = sinfo.num_frames != 0
      ? static_cast<int64_t>(sinfo.num_frames)
      : _load_audio(s, i, {}, false).size(0);

  return {
      static_cast<int64_t>(sinfo.sample_rate),
      num_frames,
      static_cast<int64_t>(sinfo.num_channels),
      static_cast<int64_t>(sinfo.bits_per_sample),
      sinfo.codec_name};
}
// Registers `load` and `info` as operators of the `torchaudio` library under
// the names `torchaudio::compat_load` and `torchaudio::compat_info`.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
// (waveform, sample_rate) loader.
m.def("torchaudio::compat_load", &load);
// (sample_rate, num_frames, num_channels, bits_per_sample, codec_name) query.
m.def("torchaudio::compat_info", &info);
}
} // namespace
} // namespace io
} // namespace torchaudio
......@@ -5,35 +5,22 @@ from typing import BinaryIO, Optional, Tuple, Union
import torch
import torchaudio
from torchaudio.io import StreamWriter
from .backend import Backend
from .common import AudioMetaData
if torchaudio._extension._FFMPEG_EXT is not None:
StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
else:
StreamReaderFileObj = object
InputType = Union[BinaryIO, str, os.PathLike]
def info_audio(
    src: str,
    format: Optional[str],
) -> AudioMetaData:
    """Query audio metadata for ``src`` via the ``torchaudio::compat_info`` operator.

    Args:
        src: Source passed straight to the operator.
        format: Optional format override passed straight to the operator.

    Returns:
        AudioMetaData built from the five values the operator returns, in
        order, with the last one (the codec name) upper-cased.
    """
    (
        sample_rate,
        num_frames,
        num_channels,
        bits_per_sample,
        codec_name,
    ) = torch.ops.torchaudio.compat_info(src, format)
    return AudioMetaData(
        sample_rate,
        num_frames,
        num_channels,
        bits_per_sample,
        codec_name.upper(),
    )
def info_audio_fileobj(
src,
src: InputType,
format: Optional[str],
buffer_size: int = 4096,
) -> AudioMetaData:
s = StreamReaderFileObj(src, format, None, buffer_size)
i = s.find_best_audio_stream()
sinfo = s.get_src_stream_info(i)
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sinfo = s.get_src_stream_info(s.default_audio_stream)
if sinfo.num_frames == 0:
waveform = _load_audio_fileobj(s)
waveform = _load_audio(s)
num_frames = waveform.size(1)
else:
num_frames = sinfo.num_frames
......@@ -42,7 +29,7 @@ def info_audio_fileobj(
num_frames,
sinfo.num_channels,
sinfo.bits_per_sample,
sinfo.codec_name.upper(),
sinfo.codec.upper(),
)
......@@ -73,35 +60,22 @@ def _get_load_filter(
return "{},{}".format(atrim, aformat)
def _load_audio_fileobj(
s: StreamReaderFileObj,
def _load_audio(
s: "torchaudio.io.StreamReader",
filter: Optional[str] = None,
channels_first: bool = True,
) -> torch.Tensor:
i = s.find_best_audio_stream()
s.add_audio_stream(i, -1, -1, filter, None, None)
s.add_audio_stream(-1, -1, filter_desc=filter)
s.process_all_packets()
chunk = s.pop_chunks()[0]
if chunk is None:
raise RuntimeError("Failed to decode audio.")
waveform = chunk.frames
waveform = chunk._elem
return waveform.T if channels_first else waveform
def load_audio(
    src: str,
    frame_offset: int = 0,
    num_frames: int = -1,
    convert: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Load audio from ``src`` via the ``torchaudio::compat_load`` operator.

    The filter description is produced by ``_get_load_filter`` from
    ``frame_offset``, ``num_frames`` and ``convert``; ``src``, ``format`` and
    ``channels_first`` are passed to the operator unchanged.

    Returns:
        Whatever the operator returns, annotated as ``(waveform, sample_rate)``.
    """
    filter_desc = _get_load_filter(frame_offset, num_frames, convert)
    loaded = torch.ops.torchaudio.compat_load(src, format, filter_desc, channels_first)
    return loaded
def load_audio_fileobj(
src: BinaryIO,
src: InputType,
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
......@@ -109,11 +83,12 @@ def load_audio_fileobj(
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
demuxer = "ogg" if format == "vorbis" else format
s = StreamReaderFileObj(src, demuxer, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
if hasattr(src, "read") and format == "vorbis":
format = "ogg"
s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.default_audio_stream).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first)
waveform = _load_audio(s, filter, channels_first)
return waveform, sample_rate
......@@ -245,7 +220,7 @@ def _parse_save_args(
def save_audio(
uri: Union[BinaryIO, str, os.PathLike],
uri: InputType,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
......@@ -268,7 +243,7 @@ def save_audio(
if channels_first:
src = src.T
s = StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
......@@ -301,18 +276,15 @@ def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> str:
class FFmpegBackend(Backend):
@staticmethod
def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(os.path.normpath(uri), format)
def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
metadata = info_audio(uri, format, buffer_size)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
@staticmethod
def load(
uri: Union[BinaryIO, str, os.PathLike],
uri: InputType,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
......@@ -320,22 +292,11 @@ class FFmpegBackend(Backend):
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
if hasattr(uri, "read"):
return load_audio_fileobj(
uri,
frame_offset,
num_frames,
normalize,
channels_first,
format,
buffer_size,
)
else:
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
@staticmethod
def save(
uri: Union[BinaryIO, str, os.PathLike],
uri: InputType,
src: torch.Tensor,
sample_rate: int,
channels_first: bool = True,
......@@ -356,9 +317,9 @@ class FFmpegBackend(Backend):
)
@staticmethod
def can_decode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
def can_decode(uri: InputType, format: Optional[str]) -> bool:
return True
@staticmethod
def can_encode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool:
def can_encode(uri: InputType, format: Optional[str]) -> bool:
return True
from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
from typing import BinaryIO, Dict, Iterator, Optional, Tuple, TypeVar, Union
......@@ -527,7 +528,7 @@ class StreamReader:
elif hasattr(src, "read"):
self._be = _StreamReaderFileObj(src, format, option, buffer_size)
else:
self._be = _StreamReader(str(src), format, option)
self._be = _StreamReader(os.path.normpath(src), format, option)
i = self._be.find_best_audio_stream()
self._default_audio_stream = None if i < 0 else i
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment