Unverified Commit be442561 authored by moto's avatar moto Committed by GitHub
Browse files

Add format override to load and related I/O functions (#1104)

parent c4f0a11b
......@@ -8,6 +8,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExtension,
get_asset_path,
get_sinusoid,
get_wav_data,
save_wav,
......@@ -243,3 +244,21 @@ class TestFileFormats(TempDirMixin, PytorchTestCase):
assert sr == expected_sr
self.assertEqual(found, expected)
@skipIfNoExtension
class TestApplyEffectFileWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
libsox does not check header for mp3
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
effects = [['band', '300', '10']]
path = get_asset_path("mp3_without_ext")
_, sr = sox_effects.apply_effects_file(path, effects, format="mp3")
assert sr == 16000
......@@ -167,3 +167,20 @@ class TestInfoOpus(PytorchTestCase):
assert info.sample_rate == 48000
assert info.num_frames == 32768
assert info.num_channels == num_channels
@skipIfNoExtension
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing `format` allows to read mp3 without extension
libsox does not check header for mp3
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
sinfo = sox_io_backend.info(path, format="mp3")
assert sinfo.sample_rate == 16000
import os
import itertools
from torchaudio.backend import sox_io_backend
......@@ -355,26 +354,18 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected)
@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoadExtensionLess(TempDirMixin, PytorchTestCase):
"""Given `format` parameter, `sox_io_backend.load` can load files without extension"""
original = None
path = None
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
def _make_file(self, format_):
sample_rate = 8000
path = self.get_temp_path(f'test.{format_}')
sox_utils.gen_audio_file(f'{path}', sample_rate, num_channels=2)
self.original = sox_io_backend.load(path)[0]
self.path = os.path.splitext(path)[0]
os.rename(path, self.path)
@parameterized.expand([
('WAV', ), ('wav', ), ('MP3', ), ('mp3', ), ('FLAC', ), ('flac',),
], name_func=name_func)
def test_format(self, format_):
"""Providing format allows to read file without extension"""
self._make_file(format_)
found, _ = sox_io_backend.load(self.path)
self.assertEqual(found, self.original)
libsox does not check header for mp3
https://github.com/pytorch/audio/issues/1040
The file was generated with the following command
ffmpeg -f lavfi -i "sine=frequency=1000:duration=5" -ar 16000 -f mp3 test_noext
"""
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
assert sr == 16000
......@@ -84,7 +84,6 @@ ExternalProject_Add(libsox
DOWNLOAD_DIR ${ARCHIVE_DIR}
URL https://downloads.sourceforge.net/project/sox/sox/14.4.2/sox-14.4.2.tar.bz2
URL_HASH SHA256=81a6956d4330e75b5827316e44ae381e6f1e8928003c6aa45896da9041ea149c
PATCH_COMMAND patch -p0 < ${CMAKE_CURRENT_SOURCE_DIR}/patch/libsox.patch
# OpenMP is by default compiled against GNU OpenMP, which conflicts with the version of OpenMP that PyTorch uses.
# See https://github.com/pytorch/audio/pull/1026
CONFIGURE_COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/build_codec_helper.sh ${CMAKE_CURRENT_SOURCE_DIR}/src/libsox/configure ${COMMON_ARGS} --with-lame --with-flac --with-mad --with-oggvorbis --without-alsa --without-coreaudio --without-png --without-oss --without-sndfile --with-opus --with-amrwb --with-amrnb --disable-openmp
......
--- src/formats.c 2014-10-27 02:55:50.000000000 +0000
+++ src/formats_new.c 2020-11-18 19:14:17.398689371 +0000
@@ -91,6 +91,21 @@
if (ext && !strcasecmp(ext, "snd"))
CHECK(sndr , 7, 1, "" , 0, 2, "\0")
+
+#if defined HAVE_MP3 && (defined STATIC_MP3 || !defined HAVE_LIBLTDL)
+ // http://www.mp3-tech.org/programmer/frame_header.html
+ // Check the first two bytes are
+ // expected 1111_1111 111X_X01X
+ // mask \xEEE6 (1111_1111 1110_0110)
+ // masked value \xEEE2 (1111_1111 1110_0010)
+ if (len >= 2 && !memcmp(data, "\xFF", 1)) {
+ unsigned char second_byte = data[1];
+ unsigned char mask = 0xE6;
+ unsigned char expected = 0xE2;
+ if ((second_byte & mask) == expected)
+ return "mp3";
+ }
+#endif
#undef CHECK
#if HAVE_MAGIC
......@@ -45,6 +45,7 @@ def load(
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Load audio data from file.
......@@ -99,6 +100,8 @@ def load(
channels_first (bool):
When True, the returned Tensor has dimension ``[channel, time]``.
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
format (str, optional):
Not used. PySoundFile does not accept format hint.
Returns:
torch.Tensor:
......
......@@ -9,20 +9,27 @@ from .common import AudioMetaData
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> AudioMetaData:
def info(
filepath: str,
format: Optional[str] = None,
) -> AudioMetaData:
"""Get signal information of an audio file.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects,
but is annotated as ``str`` for TorchScript compatibility.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
from header or extension,
Returns:
AudioMetaData: Metadata of the given audio.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
return AudioMetaData(sinfo.get_sample_rate(), sinfo.get_num_frames(), sinfo.get_num_channels())
......@@ -33,6 +40,7 @@ def load(
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Load audio data from file.
......@@ -93,6 +101,10 @@ def load(
channels_first (bool):
When True, the returned Tensor has dimension ``[channel, time]``.
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
from header or extension,
Returns:
torch.Tensor:
......@@ -103,7 +115,7 @@ def load(
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first)
filepath, frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
......
......@@ -49,7 +49,14 @@ TORCH_LIBRARY(torchaudio, m) {
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info);
m.def(
"torchaudio::sox_io_load_audio_file",
"torchaudio::sox_io_load_audio_file("
"str path,"
"int? frame_offset=None,"
"int? num_frames=None,"
"bool? normalize=True,"
"bool? channels_first=False,"
"str? format=None"
") -> __torch__.torch.classes.torchaudio.TensorSignal",
&torchaudio::sox_io::load_audio_file);
m.def(
"torchaudio::sox_io_save_audio_file",
......
......@@ -92,13 +92,14 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
const std::string path,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first) {
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
// Open input file
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf);
......
......@@ -19,7 +19,8 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file(
const std::string path,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first);
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);
} // namespace sox_effects
} // namespace torchaudio
......
......@@ -30,12 +30,14 @@ int64_t SignalInfo::getNumFrames() const {
return num_frames;
}
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
c10::intrusive_ptr<SignalInfo> get_info(
const std::string& path,
c10::optional<std::string>& format) {
SoxFormat sf(sox_open_read(
path.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error opening audio file");
......@@ -52,7 +54,8 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first) {
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
......@@ -78,7 +81,7 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
}
return torchaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first);
path, effects, normalize, channels_first, format);
}
void save_audio_file(
......
......@@ -21,14 +21,17 @@ struct SignalInfo : torch::CustomClassHolder {
int64_t getNumFrames() const;
};
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
c10::intrusive_ptr<SignalInfo> get_info(
const std::string& path,
c10::optional<std::string>& format);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first);
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);
void save_audio_file(
const std::string& file_name,
......
from typing import List, Tuple
from typing import List, Tuple, Optional
import torch
......@@ -157,6 +157,7 @@ def apply_effects_file(
effects: List[List[str]],
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
"""Apply sox effects to the audio file and load the resulting data as Tensor
......@@ -180,6 +181,10 @@ def apply_effects_file(
than integer WAV type.
channels_first (bool): When True, the returned Tensor has dimension ``[channel, time]``.
Otherwise, the returned Tensor's dimension is ``[time, channel]``.
format (str, optional):
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
from header or extension,
Returns:
Tuple[torch.Tensor, int]: Resulting Tensor and sample rate.
......@@ -249,5 +254,6 @@ def apply_effects_file(
"""
# Get string representation of 'path' in case Path object is passed
path = str(path)
signal = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first)
signal = torch.ops.torchaudio.sox_effects_apply_effects_file(
path, effects, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment