Unverified Commit 03b1d38b authored by Edgar Andrés Margffoy Tuay's avatar Edgar Andrés Margffoy Tuay Committed by GitHub
Browse files

PR: Improve handling of truncated/incomplete and corrupt JPEG images (#2471)

* Add corruption cases

* Read jpeg headers until exhaustion

* Minor error correction

* Add test script

* Raise exception when image is truncated

* Add test

* Skip damaged_jpeg folder

* Compare against basename

* Remove unused test file
parent a568c7f1
Copyright 2019 The TensorFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
import os import os
import glob
import unittest import unittest
import sys import sys
...@@ -10,11 +11,15 @@ import numpy as np ...@@ -10,11 +11,15 @@ import numpy as np
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder") IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
def get_images(directory, img_ext): def get_images(directory, img_ext):
assert os.path.isdir(directory) assert os.path.isdir(directory)
for root, _, files in os.walk(directory): for root, _, files in os.walk(directory):
if os.path.basename(root) == 'damaged_jpeg':
continue
for fl in files: for fl in files:
_, ext = os.path.splitext(fl) _, ext = os.path.splitext(fl)
if ext == img_ext: if ext == img_ext:
...@@ -44,6 +49,21 @@ class ImageTester(unittest.TestCase): ...@@ -44,6 +49,21 @@ class ImageTester(unittest.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8)) decode_jpeg(torch.empty((100), dtype=torch.uint8))
def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise)
bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')
try:
_ = read_jpeg(bad_huff)
except RuntimeError:
self.assertTrue(False)
# Truncated images should raise an exception
truncated_images = glob.glob(
os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
for image_path in truncated_images:
with self.assertRaises(RuntimeError):
read_jpeg(image_path)
def test_read_png(self): def test_read_png(self):
# Check across .png # Check across .png
for img_path in get_images(IMAGE_DIR, ".png"): for img_path in get_images(IMAGE_DIR, ".png"):
......
...@@ -48,7 +48,10 @@ static void torch_jpeg_init_source(j_decompress_ptr cinfo) {} ...@@ -48,7 +48,10 @@ static void torch_jpeg_init_source(j_decompress_ptr cinfo) {}
static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) { static boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src; torch_jpeg_mgr* src = (torch_jpeg_mgr*)cinfo->src;
// No more data. Probably an incomplete image; just output EOI. // No more data. Probably an incomplete image; Raise exception.
torch_jpeg_error_ptr myerr = (torch_jpeg_error_ptr)cinfo->err;
strcpy(jpegLastErrorMsg, "Image is incomplete or truncated");
longjmp(myerr->setjmp_buffer, 1);
src->pub.next_input_byte = EOI_BUFFER; src->pub.next_input_byte = EOI_BUFFER;
src->pub.bytes_in_buffer = 1; src->pub.bytes_in_buffer = 1;
return TRUE; return TRUE;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment