Commit fb932674, authored by Jeff Hwang and committed by Facebook GitHub Bot
Browse files

Add FFmpeg compat save function (#3058)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3058

Adds an FFmpeg-based save function.

Reviewed By: mthrok

Differential Revision: D43264858

fbshipit-source-id: ae3f89012bc2520f3de11af65348ba8f77f0acff
parent b9ef69d1
from typing import BinaryIO, Dict, Optional, Tuple import os
import sys
from typing import BinaryIO, Dict, Optional, Tuple, Union
import torch import torch
import torchaudio import torchaudio
from torchaudio.backend.common import AudioMetaData from torchaudio.backend.common import AudioMetaData
from torchaudio.io import StreamWriter
# Note: this code must comply with TorchScript syntax -- annotations are required, and neither f-strings nor globals are allowed.
...@@ -116,3 +119,120 @@ def load_audio_fileobj( ...@@ -116,3 +119,120 @@ def load_audio_fileobj(
) -> Tuple[torch.Tensor, int]: ) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size) s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
return _load_audio(s, frame_offset, num_frames, convert, channels_first) return _load_audio(s, frame_offset, num_frames, convert, channels_first)
def _get_sample_format(dtype: torch.dtype) -> str:
dtype_to_format = {
torch.uint8: "u8",
torch.int16: "s16",
torch.int32: "s32",
torch.int64: "s64",
torch.float32: "flt",
torch.float64: "dbl",
}
format = dtype_to_format.get(dtype)
if format is None:
raise ValueError(f"No format found for dtype {dtype}; dtype must be one of {list(dtype_to_format.keys())}.")
return format
def _native_endianness() -> str:
if sys.byteorder == "little":
return "le"
else:
return "be"
def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str:
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
endianness = _native_endianness()
if not encoding:
if not bits_per_sample:
# default to PCM S16
return f"pcm_s16{endianness}"
if bits_per_sample == 8:
return "pcm_u8"
return f"pcm_s{bits_per_sample}{endianness}"
if encoding == "PCM_S":
if not bits_per_sample:
bits_per_sample = 16
if bits_per_sample == 8:
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
return f"pcm_s{bits_per_sample}{endianness}"
elif encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "pcm_u8"
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
elif encoding == "PCM_F":
if not bits_per_sample:
bits_per_sample = 32
if bits_per_sample in (32, 64):
return f"pcm_f{bits_per_sample}{endianness}"
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
elif encoding == "ULAW":
if bits_per_sample in (None, 8):
return "pcm_mulaw"
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
elif encoding == "ALAW":
if bits_per_sample in (None, 8):
return "pcm_alaw"
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
raise ValueError(f"WAV encoding {encoding} is not supported.")
def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str:
if format == "wav":
return _get_encoder_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
return "flac"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
return "vorbis"
return format
def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str:
if format == "flac":
if not bits_per_sample:
return "s16"
if bits_per_sample == 24:
return "s32"
if bits_per_sample == 16:
return "s16"
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).")
return None
# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
def save_audio(
    uri: Union[BinaryIO, str, os.PathLike],
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
    buffer_size: int = 4096,
) -> None:
    """Write ``src`` to ``uri`` as audio using ``StreamWriter``.

    Args:
        uri: Destination. Either an object with a ``write`` attribute
            (file-like) or a path-like / string destination.
        src: Audio tensor to save. It is transposed before writing when
            ``channels_first`` is true (assumes a 2-D tensor -- TODO confirm).
        sample_rate: Sample rate passed to the audio stream.
        channels_first: Whether ``src`` must be transposed before writing.
        format: Output format. Required for file-like destinations; otherwise,
            when omitted, it is inferred from the file extension of ``uri``.
        encoding: Requested encoding, forwarded to encoder selection.
        bits_per_sample: Requested bit depth, forwarded to encoder selection.
        buffer_size: Buffer size passed to ``StreamWriter``.

    Raises:
        RuntimeError: If ``uri`` is file-like and ``format`` is not given.
        ValueError: If the dtype, encoding, or bit depth is rejected by the
            helper functions that pick the sample format and encoder.
    """
    if hasattr(uri, "write") and format is None:
        raise RuntimeError("'format' is required when saving to file object.")
    # The writer is created with the caller-supplied format (possibly None);
    # the inferred format below is only used to choose the encoder.
    s = StreamWriter(uri, format=format, buffer_size=buffer_size)
    if format is None:
        # Only the final path component may contribute an extension. Splitting
        # the whole string on "." would mistake dots in directory or host names
        # (e.g. "/data/v1.2/out") for an extension.
        _, dot, extension = os.path.basename(str(uri)).rpartition(".")
        if dot and extension:
            format = extension.lower()
    if channels_first:
        src = src.T
    s.add_audio_stream(
        sample_rate,
        src.size(-1),  # size of the last dimension after the optional transpose
        _get_sample_format(src.dtype),
        _get_encoder(src.dtype, format, encoding, bits_per_sample),
        {"strict": "experimental"},  # NOTE(review): presumably enables experimental encoders -- confirm
        _get_encoder_format(format, bits_per_sample),
    )
    with s.open():
        s.write_audio_chunk(0, src)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment