Commit b56f60bf authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fail on Python if sox_io info/load does not succeed (#2423)

Summary:
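Extracted from https://github.com/pytorch/audio/issues/2419. Moves the error raising for failed sox_io info/load calls from the C++ layer to the Python layer: the C++ functions now return an empty optional on failure, and the Python wrappers raise RuntimeError.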

Pull Request resolved: https://github.com/pytorch/audio/pull/2423

Reviewed By: carolineechen

Differential Revision: D36766152

Pulled By: mthrok

fbshipit-source-id: 53f897a608e97b81ebe5df29577374d88ce178f3
parent c209b70d
...@@ -435,7 +435,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -435,7 +435,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
num_channels = 2 num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)]) comments = "metadata=" + " ".join(["value" for _ in range(1000)])
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"): with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments) sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
with self._set_buffer_size(16384): with self._set_buffer_size(16384):
...@@ -583,5 +583,5 @@ class TestInfoNoSuchFile(PytorchTestCase): ...@@ -583,5 +583,5 @@ class TestInfoNoSuchFile(PytorchTestCase):
When attempted to get info on a non-existing file, error message must contain the file path. When attempted to get info on a non-existing file, error message must contain the file path.
""" """
path = "non_existing_audio.wav" path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)): with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.info(path) sox_io_backend.info(path)
...@@ -627,5 +627,5 @@ class TestLoadNoSuchFile(PytorchTestCase): ...@@ -627,5 +627,5 @@ class TestLoadNoSuchFile(PytorchTestCase):
When attempted to load a non-existing file, error message must contain the file path. When attempted to load a non-existing file, error message must contain the file path.
""" """
path = "non_existing_audio.wav" path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)): with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.load(path) sox_io_backend.load(path)
...@@ -8,6 +8,37 @@ from torchaudio._internal import module_utils as _mod_utils ...@@ -8,6 +8,37 @@ from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData from .common import AudioMetaData
# Note: must comply with TorchScript syntax -- type annotations are required and f-strings are not allowed
def _fail_info(filepath: str, format: Optional[str]) -> AudioMetaData:
raise RuntimeError("Failed to fetch metadata from {}".format(filepath))
def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData:
raise RuntimeError("Failed to fetch metadata from {}".format(fileobj))
# Note: must comply with TorchScript syntax -- type annotations are required and f-strings are not allowed
def _fail_load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath))
def _fail_load_fileobj(fileobj, *args, **kwargs):
raise RuntimeError(f"Failed to load audio from {fileobj}")
# Handlers that info() and load() call when the sox backend returns None.
# Each one is bound by default to the matching _fail_* helper, which raises
# RuntimeError.
_fallback_info = _fail_info
_fallback_info_fileobj = _fail_info_fileobj
_fallback_load = _fail_load
_fallback_load_fileobj = _fail_load_fileobj
@_mod_utils.requires_sox() @_mod_utils.requires_sox()
def info( def info(
filepath: str, filepath: str,
...@@ -46,11 +77,14 @@ def info( ...@@ -46,11 +77,14 @@ def info(
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if hasattr(filepath, "read"): if hasattr(filepath, "read"):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format) sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo) if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
assert sinfo is not None # for TorchScript compatibility if sinfo is not None:
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
return _fallback_info(filepath, format)
@_mod_utils.requires_sox() @_mod_utils.requires_sox()
...@@ -145,15 +179,19 @@ def load( ...@@ -145,15 +179,19 @@ def load(
""" """
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if hasattr(filepath, "read"): if hasattr(filepath, "read"):
return torchaudio._torchaudio.load_audio_fileobj( ret = torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format filepath, frame_offset, num_frames, normalize, channels_first, format
) )
if ret is not None:
return ret
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
ret = torch.ops.torchaudio.sox_io_load_audio_file( ret = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format filepath, frame_offset, num_frames, normalize, channels_first, format
) )
assert ret is not None # for TorchScript compatibility if ret is not None:
return ret return ret
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_sox() @_mod_utils.requires_sox()
......
...@@ -83,7 +83,10 @@ auto apply_effects_fileobj( ...@@ -83,7 +83,10 @@ auto apply_effects_fileobj(
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0 // In case of streamed data, length can be 0
validate_input_memfile(sf); if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
// Prepare output buffer // Prepare output buffer
std::vector<sox_sample_t> out_buffer; std::vector<sox_sample_t> out_buffer;
......
...@@ -60,8 +60,10 @@ auto get_info_fileobj(py::object fileobj, c10::optional<std::string> format) ...@@ -60,8 +60,10 @@ auto get_info_fileobj(py::object fileobj, c10::optional<std::string> format)
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0 if (static_cast<sox_format_t*>(sf) == nullptr ||
validate_input_memfile(sf); sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return c10::optional<MetaDataTuple>{};
}
return std::forward_as_tuple( return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate), static_cast<int64_t>(sf->signal.rate),
......
...@@ -103,7 +103,10 @@ auto apply_effects_file( ...@@ -103,7 +103,10 @@ auto apply_effects_file(
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf, path); if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
......
...@@ -19,7 +19,11 @@ c10::optional<MetaDataTuple> get_info_file( ...@@ -19,7 +19,11 @@ c10::optional<MetaDataTuple> get_info_file(
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf, path); if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
return std::forward_as_tuple( return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate), static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.length / sf->signal.channels), static_cast<int64_t>(sf->signal.length / sf->signal.channels),
......
...@@ -95,10 +95,6 @@ void validate_input_file(const SoxFormat& sf, const std::string& path) { ...@@ -95,10 +95,6 @@ void validate_input_file(const SoxFormat& sf, const std::string& path) {
} }
} }
/// Validate a SoxFormat opened from an in-memory buffer by delegating to
/// validate_input_file, passing "<in memory buffer>" in place of a file path.
void validate_input_memfile(const SoxFormat& sf) {
return validate_input_file(sf, "<in memory buffer>");
}
void validate_input_tensor(const torch::Tensor tensor) { void validate_input_tensor(const torch::Tensor tensor) {
if (!tensor.device().is_cpu()) { if (!tensor.device().is_cpu()) {
throw std::runtime_error("Input tensor has to be on CPU."); throw std::runtime_error("Input tensor has to be on CPU.");
......
...@@ -52,13 +52,6 @@ struct SoxFormat { ...@@ -52,13 +52,6 @@ struct SoxFormat {
sox_format_t* fd_; sox_format_t* fd_;
}; };
///
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf, const std::string& path);
/// Verify that input memory buffer has known encoding, and not empty
void validate_input_memfile(const SoxFormat& sf);
/// ///
/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32 /// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32
void validate_input_tensor(const torch::Tensor); void validate_input_tensor(const torch::Tensor);
......
...@@ -272,8 +272,12 @@ def apply_effects_file( ...@@ -272,8 +272,12 @@ def apply_effects_file(
""" """
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if hasattr(path, "read"): if hasattr(path, "read"):
return torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format) ret = torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
if ret is None:
raise RuntimeError("Failed to load audio from {}".format(path))
return ret
path = os.fspath(path) path = os.fspath(path)
ret = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format) ret = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
assert ret is not None if ret is not None:
return ret return ret
raise RuntimeError("Failed to load audio from {}".format(path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment