Unverified Commit d272eb0f authored by Denis Kokarev's avatar Denis Kokarev Committed by GitHub
Browse files

Add file path to io error messages (#1523)

parent c4a17027
...@@ -476,3 +476,14 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): ...@@ -476,3 +476,14 @@ class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
assert sinfo.num_frames == num_frames assert sinfo.num_frames == num_frames
assert sinfo.bits_per_sample == bits_per_sample assert sinfo.bits_per_sample == bits_per_sample
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoSox
class TestInfoNoSuchFile(PytorchTestCase):
def test_info_fail(self):
"""
When attempted to get info on a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
sox_io_backend.info(path)
...@@ -522,3 +522,14 @@ class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): ...@@ -522,3 +522,14 @@ class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
assert sr == sample_rate assert sr == sample_rate
self.assertEqual(expected, found) self.assertEqual(expected, found)
@skipIfNoSox
class TestLoadNoSuchFile(PytorchTestCase):
def test_load_fail(self):
"""
When attempted to load a non-existing file, error message must contain the file path.
"""
path = "non_existing_audio.wav"
with self.assertRaisesRegex(RuntimeError, "^Error loading audio file: failed to open file {0}$".format(path)):
sox_io_backend.load(path)
import io import io
import os
import unittest import unittest
import torch
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from parameterized import parameterized from parameterized import parameterized
...@@ -387,3 +389,14 @@ class TestSaveParams(TempDirMixin, PytorchTestCase): ...@@ -387,3 +389,14 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
sox_io_backend.save(path, data, 8000) sox_io_backend.save(path, data, 8000)
self.assertEqual(data, expected) self.assertEqual(data, expected)
@skipIfNoSox
class TestSaveNonExistingDirectory(PytorchTestCase):
def test_save_fail(self):
"""
When attempted to save into a non-existing dir, error message must contain the file path.
"""
path = os.path.join("non_existing_directory", "foo.wav")
with self.assertRaisesRegex(RuntimeError, "^Error saving audio file: failed to open file {0}$".format(path)):
sox_io_backend.save(path, torch.zeros(1, 1), 8000)
...@@ -101,7 +101,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file( ...@@ -101,7 +101,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file(
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
validate_input_file(sf); validate_input_file(sf, path);
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
...@@ -204,7 +204,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj( ...@@ -204,7 +204,7 @@ std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0 // In case of streamed data, length can be 0
validate_input_file(sf); validate_input_memfile(sf);
// Prepare output buffer // Prepare output buffer
std::vector<sox_sample_t> out_buffer; std::vector<sox_sample_t> out_buffer;
......
...@@ -19,9 +19,7 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file( ...@@ -19,9 +19,7 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) { validate_input_file(sf, path);
throw std::runtime_error("Error opening audio file");
}
return std::make_tuple( return std::make_tuple(
static_cast<int64_t>(sf->signal.rate), static_cast<int64_t>(sf->signal.rate),
...@@ -123,7 +121,8 @@ void save_audio_file( ...@@ -123,7 +121,8 @@ void save_audio_file(
/*overwrite_permitted=*/nullptr)); /*overwrite_permitted=*/nullptr));
if (static_cast<sox_format_t*>(sf) == nullptr) { if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error saving audio file: failed to open file."); throw std::runtime_error(
"Error saving audio file: failed to open file " + path);
} }
torchaudio::sox_effects_chain::SoxEffectsChain chain( torchaudio::sox_effects_chain::SoxEffectsChain chain(
...@@ -177,7 +176,7 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj( ...@@ -177,7 +176,7 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr)); /*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0 // In case of streamed data, length can be 0
validate_input_file(sf); validate_input_memfile(sf);
return std::make_tuple( return std::make_tuple(
static_cast<int64_t>(sf->signal.rate), static_cast<int64_t>(sf->signal.rate),
......
...@@ -81,15 +81,20 @@ void SoxFormat::close() { ...@@ -81,15 +81,20 @@ void SoxFormat::close() {
} }
} }
void validate_input_file(const SoxFormat& sf) { void validate_input_file(const SoxFormat& sf, const std::string& path) {
if (static_cast<sox_format_t*>(sf) == nullptr) { if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error loading audio file: failed to open file."); throw std::runtime_error(
"Error loading audio file: failed to open file " + path);
} }
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) { if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding."); throw std::runtime_error("Error loading audio file: unknown encoding.");
} }
} }
void validate_input_memfile(const SoxFormat& sf) {
return validate_input_file(sf, "<in memory buffer>");
}
void validate_input_tensor(const torch::Tensor tensor) { void validate_input_tensor(const torch::Tensor tensor) {
if (!tensor.device().is_cpu()) { if (!tensor.device().is_cpu()) {
throw std::runtime_error("Input tensor has to be on CPU."); throw std::runtime_error("Input tensor has to be on CPU.");
......
...@@ -56,7 +56,10 @@ struct SoxFormat { ...@@ -56,7 +56,10 @@ struct SoxFormat {
/// ///
/// Verify that input file is found, has known encoding, and not empty /// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf); void validate_input_file(const SoxFormat& sf, const std::string& path);
/// Verify that input memory buffer has known encoding, and not empty
void validate_input_memfile(const SoxFormat& sf);
/// ///
/// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32 /// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment