Unverified Commit 0c7a0882 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2611 from myhloli/dev

Dev
parents 3bd0ecf1 a392f445
# Copyright (c) Opendatalab. All rights reserved.
def remove_non_official_s3_args(s3path): def remove_non_official_s3_args(s3path):
......
# Copyright (c) Opendatalab. All rights reserved.
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
......
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
...@@ -19,7 +19,7 @@ class MathDataset(Dataset): ...@@ -19,7 +19,7 @@ class MathDataset(Dataset):
class UnimernetModel(object): class UnimernetModel(object):
def __init__(self, weight_dir, cfg_path, _device_="cpu"): def __init__(self, weight_dir, _device_="cpu"):
from .unimernet_hf import UnimernetModel from .unimernet_hf import UnimernetModel
if _device_.startswith("mps"): if _device_.startswith("mps"):
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager") self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
...@@ -70,7 +70,7 @@ class UnimernetModel(object): ...@@ -70,7 +70,7 @@ class UnimernetModel(object):
# Collect images with their original indices # Collect images with their original indices
for image_index in range(len(images_mfd_res)): for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index] mfd_res = images_mfd_res[image_index]
np_array_image = images[image_index] pil_img = images[image_index]
formula_list = [] formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip( for idx, (xyxy, conf, cla) in enumerate(zip(
...@@ -84,7 +84,7 @@ class UnimernetModel(object): ...@@ -84,7 +84,7 @@ class UnimernetModel(object):
"latex": "", "latex": "",
} }
formula_list.append(new_item) formula_list.append(new_item)
bbox_img = np_array_image[ymin:ymax, xmin:xmax] bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
area = (xmax - xmin) * (ymax - ymin) area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list) curr_idx = len(mf_image_list)
......
...@@ -374,6 +374,10 @@ def latex_rm_whitespace(s: str): ...@@ -374,6 +374,10 @@ def latex_rm_whitespace(s: str):
# \qquad后补空格 # \qquad后补空格
s = QQUAD_PATTERN.sub(r'\\qquad ', s) s = QQUAD_PATTERN.sub(r'\\qquad ', s)
# 如果字符串以反斜杠结尾,去掉最后的反斜杠
while s.endswith('\\'):
s = s[:-1]
return s return s
......
from PIL import Image, ImageOps
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
import numpy as np import numpy as np
import cv2 import cv2
import albumentations as alb import albumentations as alb
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible # TODO: dereference cv2 if possible
...@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -27,6 +29,21 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
image = self.prepare_input(item) image = self.prepare_input(item)
return self.transform(image=image)['image'][:1] return self.transform(image=image)['image'][:1]
@staticmethod
def crop_margin(img: Image.Image) -> Image.Image:
    """Crop the blank margin surrounding the content of *img*.

    The image is converted to grayscale and contrast-stretched to the
    full 0-255 range; every pixel darker than 200 after stretching is
    treated as content.  The returned image is the bounding box of that
    content.  A flat, single-intensity image is returned unchanged.
    """
    gray_levels = np.asarray(img.convert("L"), dtype=np.uint8)
    lo = gray_levels.min()
    hi = gray_levels.max()
    # A flat image has nothing to normalize against — leave it as-is.
    if hi == lo:
        return img
    stretched = (gray_levels - lo) / (hi - lo) * 255
    mask = (stretched < 200).astype(np.uint8) * 255
    points = cv2.findNonZero(mask)  # coordinates of every content pixel
    left, top, width, height = cv2.boundingRect(points)
    return img.crop((left, top, left + width, top + height))
@staticmethod @staticmethod
def crop_margin_numpy(img: np.ndarray) -> np.ndarray: def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
"""Crop margins of image using NumPy operations""" """Crop margins of image using NumPy operations"""
...@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -60,48 +77,73 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
if img is None: if img is None:
return None return None
# try: # Handle numpy array
# img = self.crop_margin_numpy(img) elif isinstance(img, np.ndarray):
# except Exception: try:
# # might throw an error for broken files img = self.crop_margin_numpy(img)
# return None except Exception:
# might throw an error for broken files
return None
if img.shape[0] == 0 or img.shape[1] == 0: if img.shape[0] == 0 or img.shape[1] == 0:
return None return None
# Get current dimensions # Get current dimensions
h, w = img.shape[:2] h, w = img.shape[:2]
target_h, target_w = self.input_size target_h, target_w = self.input_size
# Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail) # Calculate scale to preserve aspect ratio (equivalent to resize + thumbnail)
scale = min(target_h / h, target_w / w) scale = min(target_h / h, target_w / w)
# Calculate new dimensions # Calculate new dimensions
new_h, new_w = int(h * scale), int(w * scale) new_h, new_w = int(h * scale), int(w * scale)
# Resize the image while preserving aspect ratio # Resize the image while preserving aspect ratio
resized_img = cv2.resize(img, (new_w, new_h)) resized_img = cv2.resize(img, (new_w, new_h))
# Calculate padding values using the existing method # Calculate padding values using the existing method
delta_width = target_w - new_w delta_width = target_w - new_w
delta_height = target_h - new_h delta_height = target_h - new_h
pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding) pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
# Apply padding (convert PIL padding format to OpenCV format) # Apply padding (convert PIL padding format to OpenCV format)
padding_color = [0, 0, 0] if len(img.shape) == 3 else [0] padding_color = [0, 0, 0] if len(img.shape) == 3 else [0]
padded_img = cv2.copyMakeBorder( padded_img = cv2.copyMakeBorder(
resized_img, resized_img,
pad_height, # top pad_height, # top
delta_height - pad_height, # bottom delta_height - pad_height, # bottom
pad_width, # left pad_width, # left
delta_width - pad_width, # right delta_width - pad_width, # right
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=padding_color value=padding_color
) )
return padded_img return padded_img
# Handle PIL Image
elif isinstance(img, Image.Image):
try:
img = self.crop_margin(img.convert("RGB"))
except OSError:
# might throw an error for broken files
return None
if img.height == 0 or img.width == 0:
return None
# Resize while preserving aspect ratio
img = resize(img, min(self.input_size))
img.thumbnail((self.input_size[1], self.input_size[0]))
new_w, new_h = img.width, img.height
# Calculate and apply padding
padding = self._calculate_padding(new_w, new_h, random_padding)
return np.array(ImageOps.expand(img, padding))
else:
return None
def _calculate_padding(self, new_w, new_h, random_padding): def _calculate_padding(self, new_w, new_h, random_padding):
"""Calculate padding values for PIL images""" """Calculate padding values for PIL images"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment