"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "33d10af28fcfb4d41ab7fb97d84c8ac2317576d5"
Unverified Commit 57f7f522 authored by moto-meta's avatar moto-meta Committed by GitHub
Browse files

Remove FFmpeg compat load/info function

Differential Revision: D50229857

Pull Request resolved: https://github.com/pytorch/audio/pull/3652
parent e65e4726
...@@ -15,7 +15,6 @@ set( ...@@ -15,7 +15,6 @@ set(
stream_writer/packet_writer.cpp stream_writer/packet_writer.cpp
stream_writer/stream_writer.cpp stream_writer/stream_writer.cpp
stream_writer/tensor_converter.cpp stream_writer/tensor_converter.cpp
compat.cpp
) )
set( set(
......
#include <libtorio/ffmpeg/stream_reader/stream_reader.h>
#include <torch/script.h>
#include <stdexcept>
namespace torchaudio {
namespace io {
namespace {
torch::Tensor _load_audio(
StreamReader& s,
int i,
const c10::optional<std::string>& filter,
const bool& channels_first) {
s.add_audio_stream(i, -1, -1, filter, {}, {});
s.process_all_packets();
auto chunk = s.pop_chunks()[0];
TORCH_CHECK(chunk, "Failed to decode audio.");
auto waveform = chunk.value().frames;
return channels_first ? waveform.transpose(0, 1) : waveform;
}
std::tuple<torch::Tensor, int64_t> load(
const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<std::string>& filter,
const bool& channels_first) {
StreamReader s{src, format, {}};
auto i = s.find_best_audio_stream();
auto sample_rate = s.get_src_stream_info(i).sample_rate;
auto waveform = _load_audio(s, i, filter, channels_first);
return {waveform, sample_rate};
}
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> info(
const std::string& src,
const c10::optional<std::string>& format) {
StreamReader s{src, format, {}};
auto i = s.find_best_audio_stream();
auto sinfo = s.get_src_stream_info(i);
int64_t num_frames = [&]() {
if (sinfo.num_frames == 0) {
torch::Tensor waveform = _load_audio(s, i, {}, false);
return waveform.size(0);
}
return sinfo.num_frames;
}();
return {
static_cast<int64_t>(sinfo.sample_rate),
static_cast<int64_t>(num_frames),
static_cast<int64_t>(sinfo.num_channels),
static_cast<int64_t>(sinfo.bits_per_sample),
sinfo.codec_name};
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::compat_load", &load);
m.def("torchaudio::compat_info", &info);
}
} // namespace
} // namespace io
} // namespace torchaudio
...@@ -5,35 +5,22 @@ from typing import BinaryIO, Optional, Tuple, Union ...@@ -5,35 +5,22 @@ from typing import BinaryIO, Optional, Tuple, Union
import torch import torch
import torchaudio import torchaudio
from torchaudio.io import StreamWriter
from .backend import Backend from .backend import Backend
from .common import AudioMetaData from .common import AudioMetaData
if torchaudio._extension._FFMPEG_EXT is not None: InputType = Union[BinaryIO, str, os.PathLike]
StreamReaderFileObj = torchaudio._extension._FFMPEG_EXT.StreamReaderFileObj
else:
StreamReaderFileObj = object
def info_audio( def info_audio(
src: str, src: InputType,
format: Optional[str],
) -> AudioMetaData:
i = torch.ops.torchaudio.compat_info(src, format)
return AudioMetaData(i[0], i[1], i[2], i[3], i[4].upper())
def info_audio_fileobj(
src,
format: Optional[str], format: Optional[str],
buffer_size: int = 4096, buffer_size: int = 4096,
) -> AudioMetaData: ) -> AudioMetaData:
s = StreamReaderFileObj(src, format, None, buffer_size) s = torchaudio.io.StreamReader(src, format, None, buffer_size)
i = s.find_best_audio_stream() sinfo = s.get_src_stream_info(s.default_audio_stream)
sinfo = s.get_src_stream_info(i)
if sinfo.num_frames == 0: if sinfo.num_frames == 0:
waveform = _load_audio_fileobj(s) waveform = _load_audio(s)
num_frames = waveform.size(1) num_frames = waveform.size(1)
else: else:
num_frames = sinfo.num_frames num_frames = sinfo.num_frames
...@@ -42,7 +29,7 @@ def info_audio_fileobj( ...@@ -42,7 +29,7 @@ def info_audio_fileobj(
num_frames, num_frames,
sinfo.num_channels, sinfo.num_channels,
sinfo.bits_per_sample, sinfo.bits_per_sample,
sinfo.codec_name.upper(), sinfo.codec.upper(),
) )
...@@ -73,35 +60,22 @@ def _get_load_filter( ...@@ -73,35 +60,22 @@ def _get_load_filter(
return "{},{}".format(atrim, aformat) return "{},{}".format(atrim, aformat)
def _load_audio_fileobj( def _load_audio(
s: StreamReaderFileObj, s: "torchaudio.io.StreamReader",
filter: Optional[str] = None, filter: Optional[str] = None,
channels_first: bool = True, channels_first: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
i = s.find_best_audio_stream() s.add_audio_stream(-1, -1, filter_desc=filter)
s.add_audio_stream(i, -1, -1, filter, None, None)
s.process_all_packets() s.process_all_packets()
chunk = s.pop_chunks()[0] chunk = s.pop_chunks()[0]
if chunk is None: if chunk is None:
raise RuntimeError("Failed to decode audio.") raise RuntimeError("Failed to decode audio.")
waveform = chunk.frames waveform = chunk._elem
return waveform.T if channels_first else waveform return waveform.T if channels_first else waveform
def load_audio( def load_audio(
src: str, src: InputType,
frame_offset: int = 0,
num_frames: int = -1,
convert: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
filter = _get_load_filter(frame_offset, num_frames, convert)
return torch.ops.torchaudio.compat_load(src, format, filter, channels_first)
def load_audio_fileobj(
src: BinaryIO,
frame_offset: int = 0, frame_offset: int = 0,
num_frames: int = -1, num_frames: int = -1,
convert: bool = True, convert: bool = True,
...@@ -109,11 +83,12 @@ def load_audio_fileobj( ...@@ -109,11 +83,12 @@ def load_audio_fileobj(
format: Optional[str] = None, format: Optional[str] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
demuxer = "ogg" if format == "vorbis" else format if hasattr(src, "read") and format == "vorbis":
s = StreamReaderFileObj(src, demuxer, None, buffer_size) format = "ogg"
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate) s = torchaudio.io.StreamReader(src, format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.default_audio_stream).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert) filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first) waveform = _load_audio(s, filter, channels_first)
return waveform, sample_rate return waveform, sample_rate
...@@ -245,7 +220,7 @@ def _parse_save_args( ...@@ -245,7 +220,7 @@ def _parse_save_args(
def save_audio( def save_audio(
uri: Union[BinaryIO, str, os.PathLike], uri: InputType,
src: torch.Tensor, src: torch.Tensor,
sample_rate: int, sample_rate: int,
channels_first: bool = True, channels_first: bool = True,
...@@ -268,7 +243,7 @@ def save_audio( ...@@ -268,7 +243,7 @@ def save_audio(
if channels_first: if channels_first:
src = src.T src = src.T
s = StreamWriter(uri, format=muxer, buffer_size=buffer_size) s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream( s.add_audio_stream(
sample_rate, sample_rate,
num_channels=src.size(-1), num_channels=src.size(-1),
...@@ -301,18 +276,15 @@ def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> str: ...@@ -301,18 +276,15 @@ def _get_bits_per_sample(encoding: str, bits_per_sample: int) -> str:
class FFmpegBackend(Backend): class FFmpegBackend(Backend):
@staticmethod @staticmethod
def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_size: int = 4096) -> AudioMetaData: def info(uri: InputType, format: Optional[str], buffer_size: int = 4096) -> AudioMetaData:
if hasattr(uri, "read"): metadata = info_audio(uri, format, buffer_size)
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample) metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding) metadata.encoding = _map_encoding(metadata.encoding)
return metadata return metadata
@staticmethod @staticmethod
def load( def load(
uri: Union[BinaryIO, str, os.PathLike], uri: InputType,
frame_offset: int = 0, frame_offset: int = 0,
num_frames: int = -1, num_frames: int = -1,
normalize: bool = True, normalize: bool = True,
...@@ -320,22 +292,11 @@ class FFmpegBackend(Backend): ...@@ -320,22 +292,11 @@ class FFmpegBackend(Backend):
format: Optional[str] = None, format: Optional[str] = None,
buffer_size: int = 4096, buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
if hasattr(uri, "read"): return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
return load_audio_fileobj(
uri,
frame_offset,
num_frames,
normalize,
channels_first,
format,
buffer_size,
)
else:
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)
@staticmethod @staticmethod
def save( def save(
uri: Union[BinaryIO, str, os.PathLike], uri: InputType,
src: torch.Tensor, src: torch.Tensor,
sample_rate: int, sample_rate: int,
channels_first: bool = True, channels_first: bool = True,
...@@ -356,9 +317,9 @@ class FFmpegBackend(Backend): ...@@ -356,9 +317,9 @@ class FFmpegBackend(Backend):
) )
@staticmethod @staticmethod
def can_decode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool: def can_decode(uri: InputType, format: Optional[str]) -> bool:
return True return True
@staticmethod @staticmethod
def can_encode(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str]) -> bool: def can_encode(uri: InputType, format: Optional[str]) -> bool:
return True return True
from __future__ import annotations from __future__ import annotations
import os
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import BinaryIO, Dict, Iterator, Optional, Tuple, TypeVar, Union from typing import BinaryIO, Dict, Iterator, Optional, Tuple, TypeVar, Union
...@@ -527,7 +528,7 @@ class StreamReader: ...@@ -527,7 +528,7 @@ class StreamReader:
elif hasattr(src, "read"): elif hasattr(src, "read"):
self._be = _StreamReaderFileObj(src, format, option, buffer_size) self._be = _StreamReaderFileObj(src, format, option, buffer_size)
else: else:
self._be = _StreamReader(str(src), format, option) self._be = _StreamReader(os.path.normpath(src), format, option)
i = self._be.find_best_audio_stream() i = self._be.find_best_audio_stream()
self._default_audio_stream = None if i < 0 else i self._default_audio_stream = None if i < 0 else i
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment