Commit b56f60bf authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Fail on Python if sox_io info/load does not succeed (#2423)

Summary:
Extracted from https://github.com/pytorch/audio/issues/2419. Move the failure of sox_io from C++ to Python layer.

Pull Request resolved: https://github.com/pytorch/audio/pull/2423

Reviewed By: carolineechen

Differential Revision: D36766152

Pulled By: mthrok

fbshipit-source-id: 53f897a608e97b81ebe5df29577374d88ce178f3
parent c209b70d
......@@ -435,7 +435,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
num_channels = 2
comments = "metadata=" + " ".join(["value" for _ in range(1000)])
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file:"):
with self.assertRaisesRegex(RuntimeError, "Failed to fetch metadata from"):
sinfo = self._query_fileobj(ext, dtype, sample_rate, num_channels, num_frames, comments=comments)
with self._set_buffer_size(16384):
......@@ -583,5 +583,5 @@ class TestInfoNoSuchFile(PytorchTestCase):
When attempted to get info on a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.info(path)
......@@ -627,5 +627,5 @@ class TestLoadNoSuchFile(PytorchTestCase):
When attempted to load a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
with self.assertRaisesRegex(RuntimeError, path):
sox_io_backend.load(path)
......@@ -8,6 +8,37 @@ from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData
# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_info(filepath: str, format: Optional[str]) -> AudioMetaData:
raise RuntimeError("Failed to fetch metadata from {}".format(filepath))
def _fail_info_fileobj(fileobj, format: Optional[str]) -> AudioMetaData:
raise RuntimeError("Failed to fetch metadata from {}".format(fileobj))
# Note: need to comply TorchScript syntax -- need annotation and no f-string
def _fail_load(
filepath: str,
frame_offset: int = 0,
num_frames: int = -1,
normalize: bool = True,
channels_first: bool = True,
format: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
raise RuntimeError("Failed to load audio from {}".format(filepath))
def _fail_load_fileobj(fileobj, *args, **kwargs):
raise RuntimeError(f"Failed to load audio from {fileobj}")
_fallback_info = _fail_info
_fallback_info_fileobj = _fail_info_fileobj
_fallback_load = _fail_load
_fallback_load_fileobj = _fail_load_fileobj
@_mod_utils.requires_sox()
def info(
filepath: str,
......@@ -46,11 +77,14 @@ def info(
if not torch.jit.is_scripting():
if hasattr(filepath, "read"):
sinfo = torchaudio._torchaudio.get_info_fileobj(filepath, format)
return AudioMetaData(*sinfo)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info_fileobj(filepath, format)
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
assert sinfo is not None # for TorchScript compatibility
return AudioMetaData(*sinfo)
if sinfo is not None:
return AudioMetaData(*sinfo)
return _fallback_info(filepath, format)
@_mod_utils.requires_sox()
......@@ -145,15 +179,19 @@ def load(
"""
if not torch.jit.is_scripting():
if hasattr(filepath, "read"):
return torchaudio._torchaudio.load_audio_fileobj(
ret = torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
if ret is not None:
return ret
return _fallback_load_fileobj(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath = os.fspath(filepath)
ret = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
assert ret is not None # for TorchScript compatibility
return ret
if ret is not None:
return ret
return _fallback_load(filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_sox()
......
......@@ -83,7 +83,10 @@ auto apply_effects_fileobj(
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_memfile(sf);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
......
......@@ -60,8 +60,10 @@ auto get_info_fileobj(py::object fileobj, c10::optional<std::string> format)
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_memfile(sf);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return c10::optional<MetaDataTuple>{};
}
return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate),
......
......@@ -103,7 +103,10 @@ auto apply_effects_file(
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf, path);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
......
......@@ -19,7 +19,11 @@ c10::optional<MetaDataTuple> get_info_file(
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf, path);
if (static_cast<sox_format_t*>(sf) == nullptr ||
sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
return {};
}
return std::forward_as_tuple(
static_cast<int64_t>(sf->signal.rate),
static_cast<int64_t>(sf->signal.length / sf->signal.channels),
......
......@@ -95,10 +95,6 @@ void validate_input_file(const SoxFormat& sf, const std::string& path) {
}
}
void validate_input_memfile(const SoxFormat& sf) {
return validate_input_file(sf, "<in memory buffer>");
}
void validate_input_tensor(const torch::Tensor tensor) {
if (!tensor.device().is_cpu()) {
throw std::runtime_error("Input tensor has to be on CPU.");
......
......@@ -52,13 +52,6 @@ struct SoxFormat {
sox_format_t* fd_;
};
///
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf, const std::string& path);
/// Verify that input memory buffer has known encoding, and not empty
void validate_input_memfile(const SoxFormat& sf);
///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
void validate_input_tensor(const torch::Tensor);
......
......@@ -272,8 +272,12 @@ def apply_effects_file(
"""
if not torch.jit.is_scripting():
if hasattr(path, "read"):
return torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
ret = torchaudio._torchaudio.apply_effects_fileobj(path, effects, normalize, channels_first, format)
if ret is None:
raise RuntimeError("Failed to load audio from {}".format(path))
return ret
path = os.fspath(path)
ret = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
assert ret is not None
return ret
if ret is not None:
return ret
raise RuntimeError("Failed to load audio from {}".format(path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment