Commit 7a22bfee authored by myhloli's avatar myhloli
Browse files

refactor: enhance image margin cropping and processing for improved handling...

refactor: enhance image margin cropping and processing for improved handling of PIL and NumPy images
parent bd2c3d12
......@@ -70,7 +70,7 @@ class UnimernetModel(object):
# Collect images with their original indices
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
np_array_image = images[image_index]
pil_img = images[image_index]
formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip(
......@@ -84,7 +84,7 @@ class UnimernetModel(object):
"latex": "",
}
formula_list.append(new_item)
bbox_img = np_array_image[ymin:ymax, xmin:xmax]
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list)
......
from PIL import Image, ImageOps
from transformers.image_processing_utils import BaseImageProcessor
import numpy as np
import cv2
import albumentations as alb
from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible
......@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
image = self.prepare_input(item)
return self.transform(image=image)['image'][:1]
@staticmethod
def crop_margin(img: Image.Image) -> Image.Image:
data = np.array(img.convert("L"))
data = data.astype(np.uint8)
max_val = data.max()
min_val = data.min()
if max_val == min_val:
return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
return img.crop((a, b, w + a, h + b))
@staticmethod
def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
"""Crop margins of image using NumPy operations"""
......@@ -60,11 +77,13 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
if img is None:
return None
# try:
# img = self.crop_margin_numpy(img)
# except Exception:
# # might throw an error for broken files
# return None
# Handle numpy array
elif isinstance(img, np.ndarray):
try:
img = self.crop_margin_numpy(img)
except Exception:
# might throw an error for broken files
return None
if img.shape[0] == 0 or img.shape[1] == 0:
return None
......@@ -103,6 +122,29 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
return padded_img
# Handle PIL Image
elif isinstance(img, Image.Image):
try:
img = self.crop_margin(img.convert("RGB"))
except OSError:
# might throw an error for broken files
return None
if img.height == 0 or img.width == 0:
return None
# Resize while preserving aspect ratio
img = resize(img, min(self.input_size))
img.thumbnail((self.input_size[1], self.input_size[0]))
new_w, new_h = img.width, img.height
# Calculate and apply padding
padding = self._calculate_padding(new_w, new_h, random_padding)
return np.array(ImageOps.expand(img, padding))
else:
return None
def _calculate_padding(self, new_w, new_h, random_padding):
"""Calculate padding values for PIL images"""
delta_width = self.input_size[1] - new_w
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment