Unverified Commit 671569dd authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Put `load_image` function in `image_utils.py` & fix image rotation issue (#14062)

* Fix img load rotation

* Add `load_image` to `image_utils.py`

* Implement LoadImageTester

* Use hf-internal-testing dataset

* Add img utils comments

* Refactor LoadImageTester

* Import load_image under is_vision_available
parent 89766b3d
......@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Union
import numpy as np
import PIL.Image
import PIL.ImageOps
import requests
from .file_utils import _is_torch, is_torch_available
......@@ -35,6 +39,39 @@ def is_torch_tensor(obj):
return _is_torch(obj) if is_torch_available() else False
def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
"""
Loads :obj:`image` to a PIL Image.
Args:
image (:obj:`str` or :obj:`PIL.Image.Image`):
The image to convert to the PIL Image format.
Returns:
:obj:`PIL.Image.Image`: A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
......
import os
from typing import List, Union
import requests
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
......@@ -11,6 +8,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available():
from PIL import Image
from ..image_utils import load_image
if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
......@@ -39,35 +38,13 @@ class ImageClassificationPipeline(Pipeline):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image
def _sanitize_parameters(self, top_k=None):
postprocess_params = {}
if top_k is not None:
postprocess_params["top_k"] = top_k
return {}, {}, postprocess_params
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs):
def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
Assign labels to the image(s) passed as inputs.
......@@ -99,7 +76,7 @@ class ImageClassificationPipeline(Pipeline):
return super().__call__(images, **kwargs)
def preprocess(self, image):
image = self.load_image(image)
image = load_image(image)
model_inputs = self.feature_extractor(images=image, return_tensors="pt")
return model_inputs
......
import base64
import io
import os
from typing import Any, Dict, List, Union
import numpy as np
import requests
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
......@@ -15,6 +12,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available():
from PIL import Image
from ..image_utils import load_image
if is_torch_available():
import torch
......@@ -49,28 +48,6 @@ class ImageSegmentationPipeline(Pipeline):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
pass
else:
raise ValueError(
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image
def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
if "threshold" in kwargs:
......@@ -118,7 +95,7 @@ class ImageSegmentationPipeline(Pipeline):
return torch.no_grad
def preprocess(self, image):
image = self.load_image(image)
image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size
......
import os
from typing import Any, Dict, List, Union
import requests
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available():
from PIL import Image
from ..image_utils import load_image
if is_torch_available():
import torch
......@@ -45,28 +43,6 @@ class ObjectDetectionPipeline(Pipeline):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
pass
else:
raise ValueError(
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image
def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {}
if "threshold" in kwargs:
......@@ -105,7 +81,7 @@ class ObjectDetectionPipeline(Pipeline):
return super().__call__(*args, **kwargs)
def preprocess(self, image):
image = self.load_image(image)
image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size
......
......@@ -15,6 +15,7 @@
import unittest
import datasets
import numpy as np
from transformers import is_torch_available, is_vision_available
......@@ -28,6 +29,7 @@ if is_vision_available():
import PIL.Image
from transformers import ImageFeatureExtractionMixin
from transformers.image_utils import load_image
def get_random_image(height, width):
......@@ -367,3 +369,68 @@ class ImageFeatureExtractionTester(unittest.TestCase):
# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
@require_vision
class LoadImageTester(unittest.TestCase):
def test_load_img_local(self):
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(480, 640, 3),
)
def test_load_img_rgba(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img = load_image(dataset[0]["file"]) # img with mode RGBA
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(512, 512, 3),
)
def test_load_img_la(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img = load_image(dataset[1]["file"]) # img with mode LA
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(512, 768, 3),
)
def test_load_img_l(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img = load_image(dataset[2]["file"]) # img with mode L
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(381, 225, 3),
)
def test_load_img_exif_transpose(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img_file = dataset[3]["file"]
img_without_exif_transpose = PIL.Image.open(img_file)
img_arr_without_exif_transpose = np.array(img_without_exif_transpose)
self.assertEqual(
img_arr_without_exif_transpose.shape,
(333, 500, 3),
)
img_with_exif_transpose = load_image(img_file)
img_arr_with_exif_transpose = np.array(img_with_exif_transpose)
self.assertEqual(
img_arr_with_exif_transpose.shape,
(500, 333, 3),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment