Commit 7a22bfee authored by myhloli's avatar myhloli
Browse files

refactor: enhance image margin cropping and processing for improved handling of PIL and NumPy images

refactor: enhance image margin cropping and processing for improved handling of PIL and NumPy images
parent bd2c3d12
...@@ -70,7 +70,7 @@ class UnimernetModel(object): ...@@ -70,7 +70,7 @@ class UnimernetModel(object):
# Collect images with their original indices # Collect images with their original indices
for image_index in range(len(images_mfd_res)): for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index] mfd_res = images_mfd_res[image_index]
np_array_image = images[image_index] pil_img = images[image_index]
formula_list = [] formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip( for idx, (xyxy, conf, cla) in enumerate(zip(
...@@ -84,7 +84,7 @@ class UnimernetModel(object): ...@@ -84,7 +84,7 @@ class UnimernetModel(object):
"latex": "", "latex": "",
} }
formula_list.append(new_item) formula_list.append(new_item)
bbox_img = np_array_image[ymin:ymax, xmin:xmax] bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
area = (xmax - xmin) * (ymax - ymin) area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list) curr_idx = len(mf_image_list)
......
from PIL import Image, ImageOps
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
import numpy as np import numpy as np
import cv2 import cv2
import albumentations as alb import albumentations as alb
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible # TODO: dereference cv2 if possible
...@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
image = self.prepare_input(item) image = self.prepare_input(item)
return self.transform(image=image)['image'][:1] return self.transform(image=image)['image'][:1]
@staticmethod
def crop_margin(img: Image.Image) -> Image.Image:
    """Trim the uniform margin around a PIL image.

    The image is converted to grayscale, contrast-stretched to the full
    0-255 range, and thresholded so dark pixels (the content, e.g. formula
    ink) become non-zero. The image is then cropped to the bounding box of
    those pixels. A flat single-color image is returned unchanged.
    """
    gray_data = np.asarray(img.convert("L"), dtype=np.uint8)
    lo = gray_data.min()
    hi = gray_data.max()
    if hi == lo:
        # Flat image: no contrast, nothing to crop against.
        return img
    # Stretch contrast so the darkest pixel maps to 0 and the brightest to 255.
    normalized = (gray_data - lo) / (hi - lo) * 255
    # Mark sufficiently dark pixels (the content) as non-zero.
    mask = 255 * (normalized < 200).astype(np.uint8)
    points = cv2.findNonZero(mask)           # all non-zero (content) points
    x, y, width, height = cv2.boundingRect(points)  # tight bounding box
    return img.crop((x, y, x + width, y + height))
@staticmethod @staticmethod
def crop_margin_numpy(img: np.ndarray) -> np.ndarray: def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
"""Crop margins of image using NumPy operations""" """Crop margins of image using NumPy operations"""
...@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
if img is None: if img is None:
return None return None
# try: # Handle numpy array
# img = self.crop_margin_numpy(img) elif isinstance(img, np.ndarray):
# except Exception: try:
# # might throw an error for broken files img = self.crop_margin_numpy(img)
# return None except Exception:
# might throw an error for broken files
return None
if img.shape[0] == 0 or img.shape[1] == 0: if img.shape[0] == 0 or img.shape[1] == 0:
return None return None
# Get current dimensions # Get current dimensions
h, w = img.shape[:2] h, w = img.shape[:2]
target_h, target_w = self.input_size target_h, target_w = self.input_size
# Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail) # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
scale = min(target_h / h, target_w / w) scale = min(target_h / h, target_w / w)
# Calculate new dimensions # Calculate new dimensions
new_h, new_w = int(h * scale), int(w * scale) new_h, new_w = int(h * scale), int(w * scale)
# Resize the image while preserving aspect ratio # Resize the image while preserving aspect ratio
resized_img = cv2.resize(img, (new_w, new_h)) resized_img = cv2.resize(img, (new_w, new_h))
# Calculate padding values using the existing method # Calculate padding values using the existing method
delta_width = target_w - new_w delta_width = target_w - new_w
delta_height = target_h - new_h delta_height = target_h - new_h
pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding) pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
# Apply padding (convert PIL padding format to OpenCV format) # Apply padding (convert PIL padding format to OpenCV format)
padding_color = [0, 0, 0] if len(img.shape) == 3 else [0] padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
padded_img = cv2.copyMakeBorder( padded_img = cv2.copyMakeBorder(
resized_img, resized_img,
pad_height, # top pad_height, # top
delta_height - pad_height, # bottom delta_height - pad_height, # bottom
pad_width, # left pad_width, # left
delta_width - pad_width, # right delta_width - pad_width, # right
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=padding_color value=padding_color
) )
return padded_img return padded_img
# Handle PIL Image
elif isinstance(img, Image.Image):
try:
img = self.crop_margin(img.convert("RGB"))
except OSError:
# might throw an error for broken files
return None
if img.height == 0 or img.width == 0:
return None
# Resize while preserving aspect ratio
img = resize(img, min(self.input_size))
img.thumbnail((self.input_size[1], self.input_size[0]))
new_w, new_h = img.width, img.height
# Calculate and apply padding
padding = self._calculate_padding(new_w, new_h, random_padding)
return np.array(ImageOps.expand(img, padding))
else:
return None
def _calculate_padding(self, new_w, new_h, random_padding): def _calculate_padding(self, new_w, new_h, random_padding):
"""Calculate padding values for PIL images""" """Calculate padding values for PIL images"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment