Unverified Commit fec39cd5 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Added check for invalid input file (#3932)

parent 1cbcb2b1
import collections import collections
import math import math
import os import os
import time
import unittest import unittest
from fractions import Fraction from fractions import Fraction
...@@ -9,6 +8,7 @@ import numpy as np ...@@ -9,6 +8,7 @@ import numpy as np
import torch import torch
import torchvision.io as io import torchvision.io as io
from numpy.random import randint from numpy.random import randint
from torchvision import set_video_backend
from torchvision.io import _HAS_VIDEO_OPT from torchvision.io import _HAS_VIDEO_OPT
from common_utils import PY39_SKIP from common_utils import PY39_SKIP
from _assert_utils import assert_equal from _assert_utils import assert_equal
...@@ -23,9 +23,6 @@ except ImportError: ...@@ -23,9 +23,6 @@ except ImportError:
av = None av = None
from urllib.error import URLError
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
CheckerConfig = [ CheckerConfig = [
...@@ -1277,6 +1274,15 @@ class TestVideoReader(unittest.TestCase): ...@@ -1277,6 +1274,15 @@ class TestVideoReader(unittest.TestCase):
self.assertGreaterEqual( self.assertGreaterEqual(
torch.mean(torch.isclose(audio.float(), arr).float()), 0.99) torch.mean(torch.isclose(audio.float(), arr).float()), 0.99)
def test_invalid_file(self):
set_video_backend('video_reader')
with self.assertRaises(RuntimeError):
io.read_video('foo.mp4')
set_video_backend('pyav')
with self.assertRaises(RuntimeError):
io.read_video('foo.mp4')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
import gc import gc
import math import math
import os
import re import re
import warnings import warnings
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
...@@ -258,6 +259,9 @@ def read_video( ...@@ -258,6 +259,9 @@ def read_video(
from torchvision import get_video_backend from torchvision import get_video_backend
if not os.path.exists(filename):
raise RuntimeError(f'File not found: {filename}')
if get_video_backend() != "pyav": if get_video_backend() != "pyav":
return _video_opt._read_video(filename, start_pts, end_pts, pts_unit) return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment