"...training/git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "ac26d1fb3dd287ea88d5b5fe7ed629e9bb6cd875"
Unverified Commit 671569dd authored by Mishig Davaadorj's avatar Mishig Davaadorj Committed by GitHub
Browse files

Put `load_image` function in `image_utils.py` & fix image rotation issue (#14062)

* Fix img load rotation

* Add `load_image` to `image_utils.py`

* Implement LoadImageTester

* Use hf-internal-testing dataset

* Add img utils comments

* Refactor LoadImageTester

* Import load_image under is_vision_available
parent 89766b3d
...@@ -13,10 +13,14 @@ ...@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from typing import List, Union from typing import List, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import PIL.ImageOps
import requests
from .file_utils import _is_torch, is_torch_available from .file_utils import _is_torch, is_torch_available
...@@ -35,6 +39,39 @@ def is_torch_tensor(obj): ...@@ -35,6 +39,39 @@ def is_torch_tensor(obj):
return _is_torch(obj) if is_torch_available() else False return _is_torch(obj) if is_torch_available() else False
def load_image(image: Union[str, "PIL.Image.Image"]) -> "PIL.Image.Image":
"""
Loads :obj:`image` to a PIL Image.
Args:
image (:obj:`str` or :obj:`PIL.Image.Image`):
The image to convert to the PIL Image format.
Returns:
:obj:`PIL.Image.Image`: A PIL Image.
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = PIL.Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, PIL.Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
# In the future we can add a TF implementation here when we have TF models. # In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin: class ImageFeatureExtractionMixin:
""" """
......
import os
from typing import List, Union from typing import List, Union
import requests
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
...@@ -11,6 +8,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline ...@@ -11,6 +8,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
...@@ -39,35 +38,13 @@ class ImageClassificationPipeline(Pipeline): ...@@ -39,35 +38,13 @@ class ImageClassificationPipeline(Pipeline):
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING) self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
image = image
else:
raise ValueError(
"Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image
def _sanitize_parameters(self, top_k=None): def _sanitize_parameters(self, top_k=None):
postprocess_params = {} postprocess_params = {}
if top_k is not None: if top_k is not None:
postprocess_params["top_k"] = top_k postprocess_params["top_k"] = top_k
return {}, {}, postprocess_params return {}, {}, postprocess_params
def __call__(self, images: Union[str, List[str], "Image", List["Image"]], **kwargs): def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
""" """
Assign labels to the image(s) passed as inputs. Assign labels to the image(s) passed as inputs.
...@@ -99,7 +76,7 @@ class ImageClassificationPipeline(Pipeline): ...@@ -99,7 +76,7 @@ class ImageClassificationPipeline(Pipeline):
return super().__call__(images, **kwargs) return super().__call__(images, **kwargs)
def preprocess(self, image): def preprocess(self, image):
image = self.load_image(image) image = load_image(image)
model_inputs = self.feature_extractor(images=image, return_tensors="pt") model_inputs = self.feature_extractor(images=image, return_tensors="pt")
return model_inputs return model_inputs
......
import base64 import base64
import io import io
import os
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import numpy as np import numpy as np
import requests
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
...@@ -15,6 +12,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline ...@@ -15,6 +12,8 @@ from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available(): if is_vision_available():
from PIL import Image from PIL import Image
from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -49,28 +48,6 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -49,28 +48,6 @@ class ImageSegmentationPipeline(Pipeline):
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING) self.check_model_type(MODEL_FOR_IMAGE_SEGMENTATION_MAPPING)
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
pass
else:
raise ValueError(
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {} postprocess_kwargs = {}
if "threshold" in kwargs: if "threshold" in kwargs:
...@@ -118,7 +95,7 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -118,7 +95,7 @@ class ImageSegmentationPipeline(Pipeline):
return torch.no_grad return torch.no_grad
def preprocess(self, image): def preprocess(self, image):
image = self.load_image(image) image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]]) target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt") inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size inputs["target_size"] = target_size
......
import os
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
import requests
from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends from ..file_utils import add_end_docstrings, is_torch_available, is_vision_available, requires_backends
from ..utils import logging from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline from .base import PIPELINE_INIT_ARGS, Pipeline
if is_vision_available(): if is_vision_available():
from PIL import Image from ..image_utils import load_image
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -45,28 +43,6 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -45,28 +43,6 @@ class ObjectDetectionPipeline(Pipeline):
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING) self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING)
@staticmethod
def load_image(image: Union[str, "Image.Image"]):
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
image = Image.open(requests.get(image, stream=True).raw)
elif os.path.isfile(image):
image = Image.open(image)
else:
raise ValueError(
f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path"
)
elif isinstance(image, Image.Image):
pass
else:
raise ValueError(
"Incorrect format used for image. Should be a URL linking to an image, a local path, or a PIL image."
)
image = image.convert("RGB")
return image
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {} postprocess_kwargs = {}
if "threshold" in kwargs: if "threshold" in kwargs:
...@@ -105,7 +81,7 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -105,7 +81,7 @@ class ObjectDetectionPipeline(Pipeline):
return super().__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
def preprocess(self, image): def preprocess(self, image):
image = self.load_image(image) image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]]) target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt") inputs = self.feature_extractor(images=[image], return_tensors="pt")
inputs["target_size"] = target_size inputs["target_size"] = target_size
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import datasets
import numpy as np import numpy as np
from transformers import is_torch_available, is_vision_available from transformers import is_torch_available, is_vision_available
...@@ -28,6 +29,7 @@ if is_vision_available(): ...@@ -28,6 +29,7 @@ if is_vision_available():
import PIL.Image import PIL.Image
from transformers import ImageFeatureExtractionMixin from transformers import ImageFeatureExtractionMixin
from transformers.image_utils import load_image
def get_random_image(height, width): def get_random_image(height, width):
...@@ -367,3 +369,68 @@ class ImageFeatureExtractionTester(unittest.TestCase): ...@@ -367,3 +369,68 @@ class ImageFeatureExtractionTester(unittest.TestCase):
# Check result is consistent with PIL.Image.crop # Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size) cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image)))) self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))
@require_vision
class LoadImageTester(unittest.TestCase):
def test_load_img_local(self):
img = load_image("./tests/fixtures/tests_samples/COCO/000000039769.png")
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(480, 640, 3),
)
def test_load_img_rgba(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img = load_image(dataset[0]["file"]) # img with mode RGBA
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(512, 512, 3),
)
def test_load_img_la(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img = load_image(dataset[1]["file"]) # img with mode LA
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(512, 768, 3),
)
def test_load_img_l(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img = load_image(dataset[2]["file"]) # img with mode L
img_arr = np.array(img)
self.assertEqual(
img_arr.shape,
(381, 225, 3),
)
def test_load_img_exif_transpose(self):
dataset = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", "image", split="test")
img_file = dataset[3]["file"]
img_without_exif_transpose = PIL.Image.open(img_file)
img_arr_without_exif_transpose = np.array(img_without_exif_transpose)
self.assertEqual(
img_arr_without_exif_transpose.shape,
(333, 500, 3),
)
img_with_exif_transpose = load_image(img_file)
img_arr_with_exif_transpose = np.array(img_with_exif_transpose)
self.assertEqual(
img_arr_with_exif_transpose.shape,
(500, 333, 3),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment