Unverified Commit 0c7a0882 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2611 from myhloli/dev

Dev
parents 3bd0ecf1 a392f445
# Copyright (c) Opendatalab. All rights reserved.
def remove_non_official_s3_args(s3path):
......
# Copyright (c) Opendatalab. All rights reserved.
from pydantic import BaseModel, Field
......
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
......
@@ -19,7 +19,7 @@ class MathDataset(Dataset):
 class UnimernetModel(object):
-    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
+    def __init__(self, weight_dir, _device_="cpu"):
         from .unimernet_hf import UnimernetModel
         if _device_.startswith("mps"):
             self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
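For context on the `_device_.startswith("mps")` branch above: some fused attention paths in recent transformers/PyTorch releases have been unreliable on Apple's MPS backend, so forcing `attn_implementation="eager"` when loading on MPS is a common workaround. A minimal sketch of the same pattern with a generic Hugging Face model (`AutoModel` and the `load_formula_model` helper are illustrative stand-ins, not code from this PR):

import torch
from transformers import AutoModel

def load_formula_model(weight_dir: str, device: str = "cpu"):
    # On MPS (Apple Silicon), fall back to the plain eager attention implementation;
    # elsewhere, let transformers pick its default (e.g. SDPA) implementation.
    if device.startswith("mps"):
        model = AutoModel.from_pretrained(weight_dir, attn_implementation="eager")
    else:
        model = AutoModel.from_pretrained(weight_dir)
    return model.to(torch.device(device)).eval()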
......
@@ -70,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            np_array_image = images[image_index]
+            pil_img = images[image_index]
             formula_list = []
             for idx, (xyxy, conf, cla) in enumerate(zip(
......
@@ -84,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                 area = (xmax - xmin) * (ymax - ymin)
                 curr_idx = len(mf_image_list)
......
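The `bbox_img` change above swaps NumPy slicing for `PIL.Image.crop`, matching the switch from `np_array_image` to `pil_img`. For an axis-aligned box the two are equivalent; only the argument convention differs (`[ymin:ymax, xmin:xmax]` versus `(left, upper, right, lower)`, both exclusive on the far edges). A small self-contained check of that equivalence on a random dummy image:

import numpy as np
from PIL import Image

np_img = np.random.randint(0, 255, (100, 80, 3), dtype=np.uint8)   # dummy 100x80 RGB image
xmin, ymin, xmax, ymax = 10, 20, 50, 60

crop_np = np_img[ymin:ymax, xmin:xmax]                              # rows, then columns
crop_pil = Image.fromarray(np_img).crop((xmin, ymin, xmax, ymax))   # left, upper, right, lower

assert np.array_equal(crop_np, np.array(crop_pil))  # identical pixels either way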
......
@@ -374,6 +374,10 @@ def latex_rm_whitespace(s: str):
     # add a space after \qquad
     s = QQUAD_PATTERN.sub(r'\\qquad ', s)
+    # if the string ends with a backslash, strip the trailing backslashes
+    while s.endswith('\\'):
+        s = s[:-1]
     return s
......
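The lines added to `latex_rm_whitespace` handle predictions that end with a bare backslash: when such a string is later wrapped in math delimiters, the trailing `\` escapes the closing delimiter and breaks rendering. A standalone sketch of just that step (the full function and `QQUAD_PATTERN` are not shown in this hunk):

def strip_trailing_backslashes(s: str) -> str:
    # A trailing "\" would escape whatever character follows the formula,
    # e.g. a closing "$" delimiter, so drop any trailing backslashes.
    while s.endswith('\\'):
        s = s[:-1]
    return s

assert strip_trailing_backslashes('x + y \\\\') == 'x + y '
assert strip_trailing_backslashes('a^2') == 'a^2'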
from PIL import Image, ImageOps
from transformers.image_processing_utils import BaseImageProcessor
import numpy as np
import cv2
import albumentations as alb
from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible
......
@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
         image = self.prepare_input(item)
         return self.transform(image=image)['image'][:1]
+    @staticmethod
+    def crop_margin(img: Image.Image) -> Image.Image:
+        data = np.array(img.convert("L"))
+        data = data.astype(np.uint8)
+        max_val = data.max()
+        min_val = data.min()
+        if max_val == min_val:
+            return img
+        data = (data - min_val) / (max_val - min_val) * 255
+        gray = 255 * (data < 200).astype(np.uint8)
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        return img.crop((a, b, w + a, h + b))
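The new `crop_margin` stretches the grayscale image to the full 0-255 range, marks every pixel darker than 200 as ink, and crops to the bounding rectangle of those pixels, so large white margins around a formula are removed before resizing. A rough standalone demo of the same idea on a synthetic image (the drawn rectangle is purely illustrative):

import numpy as np
import cv2
from PIL import Image, ImageDraw

img = Image.new("RGB", (200, 100), "white")                     # mostly white canvas
ImageDraw.Draw(img).rectangle([80, 40, 120, 60], fill="black")  # small dark "formula"

data = np.array(img.convert("L")).astype(np.float32)
data = (data - data.min()) / (data.max() - data.min()) * 255    # stretch contrast to 0-255
gray = 255 * (data < 200).astype(np.uint8)                      # 255 where ink, 0 where background
x, y, w, h = cv2.boundingRect(cv2.findNonZero(gray))            # tight box around the ink
print(img.crop((x, y, x + w, y + h)).size)                      # -> (41, 21), the drawn box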
    @staticmethod
    def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
        """Crop margins of image using NumPy operations"""
......
@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
        if img is None:
            return None
        # try:
        #     img = self.crop_margin_numpy(img)
        # except Exception:
        #     # might throw an error for broken files
        #     return None
        # Handle numpy array
        elif isinstance(img, np.ndarray):
            try:
                img = self.crop_margin_numpy(img)
            except Exception:
                # might throw an error for broken files
                return None

            if img.shape[0] == 0 or img.shape[1] == 0:
                return None

            # Get current dimensions
            h, w = img.shape[:2]
            target_h, target_w = self.input_size

            # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
            scale = min(target_h / h, target_w / w)

            # Calculate new dimensions
            new_h, new_w = int(h * scale), int(w * scale)

            # Resize the image while preserving aspect ratio
            resized_img = cv2.resize(img, (new_w, new_h))

            # Calculate padding values using the existing method
            delta_width = target_w - new_w
            delta_height = target_h - new_h
            pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)

            # Apply padding (convert PIL padding format to OpenCV format)
            padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
            padded_img = cv2.copyMakeBorder(
                resized_img,
                pad_height,                 # top
                delta_height - pad_height,  # bottom
                pad_width,                  # left
                delta_width - pad_width,    # right
                cv2.BORDER_CONSTANT,
                value=padding_color
            )

            return padded_img

        # Handle PIL Image
        elif isinstance(img, Image.Image):
            try:
                img = self.crop_margin(img.convert("RGB"))
            except OSError:
                # might throw an error for broken files
                return None

            if img.height == 0 or img.width == 0:
                return None

            # Resize while preserving aspect ratio
            img = resize(img, min(self.input_size))
            img.thumbnail((self.input_size[1], self.input_size[0]))
            new_w, new_h = img.width, img.height

            # Calculate and apply padding
            padding = self._calculate_padding(new_w, new_h, random_padding)
            return np.array(ImageOps.expand(img, padding))

        else:
            return None

    def _calculate_padding(self, new_w, new_h, random_padding):
        """Calculate padding values for PIL images"""
......
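Both branches of `prepare_input` implement the same letterboxing scheme: scale the crop so it fits inside `input_size` without changing its aspect ratio, then pad the remainder out to the exact target shape (offset randomly when `random_padding` is set, deterministically otherwise). A small numeric sketch of that arithmetic with centered padding; the 192x672 (height x width) target is only an example, the real value comes from the processor's `input_size`:

import cv2
import numpy as np

def letterbox(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    h, w = img.shape[:2]
    scale = min(target_h / h, target_w / w)            # largest scale that still fits
    new_h, new_w = int(h * scale), int(w * scale)
    resized = cv2.resize(img, (new_w, new_h))          # cv2.resize takes (width, height)
    pad_h, pad_w = target_h - new_h, target_w - new_w
    top, left = pad_h // 2, pad_w // 2                 # center the image on the canvas
    return cv2.copyMakeBorder(resized, top, pad_h - top, left, pad_w - left,
                              cv2.BORDER_CONSTANT, value=[0, 0, 0])

crop = np.full((100, 400, 3), 255, dtype=np.uint8)     # e.g. a wide 100x400 formula crop
print(letterbox(crop, 192, 672).shape)                 # (192, 672, 3): scaled by 1.68, then padded 24 px vertically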