Commit 7a22bfee authored by myhloli's avatar myhloli
Browse files

refactor: enhance image margin cropping and processing for improved handling...

refactor: enhance image margin cropping and processing for improved handling of PIL and NumPy images
parent bd2c3d12
...@@ -70,7 +70,7 @@ class UnimernetModel(object): ...@@ -70,7 +70,7 @@ class UnimernetModel(object):
# Collect images with their original indices # Collect images with their original indices
for image_index in range(len(images_mfd_res)): for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index] mfd_res = images_mfd_res[image_index]
np_array_image = images[image_index] pil_img = images[image_index]
formula_list = [] formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip( for idx, (xyxy, conf, cla) in enumerate(zip(
...@@ -84,7 +84,7 @@ class UnimernetModel(object): ...@@ -84,7 +84,7 @@ class UnimernetModel(object):
"latex": "", "latex": "",
} }
formula_list.append(new_item) formula_list.append(new_item)
bbox_img = np_array_image[ymin:ymax, xmin:xmax] bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
area = (xmax - xmin) * (ymax - ymin) area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list) curr_idx = len(mf_image_list)
......
from PIL import Image, ImageOps
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
import numpy as np import numpy as np
import cv2 import cv2
import albumentations as alb import albumentations as alb
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible # TODO: dereference cv2 if possible
...@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
image = self.prepare_input(item) image = self.prepare_input(item)
return self.transform(image=image)['image'][:1] return self.transform(image=image)['image'][:1]
@staticmethod
def crop_margin(img: Image.Image) -> Image.Image:
data = np.array(img.convert("L"))
data = data.astype(np.uint8)
max_val = data.max()
min_val = data.min()
if max_val == min_val:
return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
return img.crop((a, b, w + a, h + b))
@staticmethod @staticmethod
def crop_margin_numpy(img: np.ndarray) -> np.ndarray: def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
"""Crop margins of image using NumPy operations""" """Crop margins of image using NumPy operations"""
...@@ -60,11 +77,13 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -60,11 +77,13 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
if img is None: if img is None:
return None return None
# try: # Handle numpy array
# img = self.crop_margin_numpy(img) elif isinstance(img, np.ndarray):
# except Exception: try:
# # might throw an error for broken files img = self.crop_margin_numpy(img)
# return None except Exception:
# might throw an error for broken files
return None
if img.shape[0] == 0 or img.shape[1] == 0: if img.shape[0] == 0 or img.shape[1] == 0:
return None return None
...@@ -103,6 +122,29 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -103,6 +122,29 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
return padded_img return padded_img
# Handle PIL Image
elif isinstance(img, Image.Image):
try:
img = self.crop_margin(img.convert("RGB"))
except OSError:
# might throw an error for broken files
return None
if img.height == 0 or img.width == 0:
return None
# Resize while preserving aspect ratio
img = resize(img, min(self.input_size))
img.thumbnail((self.input_size[1], self.input_size[0]))
new_w, new_h = img.width, img.height
# Calculate and apply padding
padding = self._calculate_padding(new_w, new_h, random_padding)
return np.array(ImageOps.expand(img, padding))
else:
return None
def _calculate_padding(self, new_w, new_h, random_padding): def _calculate_padding(self, new_w, new_h, random_padding):
"""Calculate padding values for PIL images""" """Calculate padding values for PIL images"""
delta_width = self.input_size[1] - new_w delta_width = self.input_size[1] - new_w
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment