from typing import List, Optional, Union, Tuple
import cv2
import numpy as np
from supervision.detection.core import Detections
from supervision.draw.color import Color, ColorPalette
class BoxAnnotator:
"""
A class for drawing bounding boxes on an image using detections provided.
Attributes:
color (Union[Color, ColorPalette]): The color to draw the bounding box,
can be a single color or a color palette
thickness (int): The thickness of the bounding box lines, default is 2
text_color (Color): The color of the text on the bounding box, default is white
text_scale (float): The scale of the text on the bounding box, default is 0.5
text_thickness (int): The thickness of the text on the bounding box,
default is 1
text_padding (int): The padding around the text on the bounding box,
default is 5
"""
def __init__(
self,
color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
thickness: int = 3, # 1 for seeclick 2 for mind2web and 3 for demo
text_color: Color = Color.BLACK,
text_scale: float = 0.5, # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
text_thickness: int = 2, #1, # 2 for demo
text_padding: int = 10,
avoid_overlap: bool = True,
):
self.color: Union[Color, ColorPalette] = color
self.thickness: int = thickness
self.text_color: Color = text_color
self.text_scale: float = text_scale
self.text_thickness: int = text_thickness
self.text_padding: int = text_padding
self.avoid_overlap: bool = avoid_overlap
def annotate(
self,
scene: np.ndarray,
detections: Detections,
labels: Optional[List[str]] = None,
skip_label: bool = False,
image_size: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
"""
Draws bounding boxes on the frame using the detections provided.
Args:
scene (np.ndarray): The image on which the bounding boxes will be drawn
detections (Detections): The detections for which the
bounding boxes will be drawn
labels (Optional[List[str]]): An optional list of labels
corresponding to each detection. If `labels` are not provided,
corresponding `class_id` will be used as label.
            skip_label (bool): If set to `True`, skips bounding box label annotation.
            image_size (Optional[Tuple[int, int]]): The (width, height) of the image,
                used to keep labels inside the image when `avoid_overlap` is enabled.
Returns:
np.ndarray: The image with the bounding boxes drawn on it
Example:
```python
import supervision as sv
classes = ['person', ...]
image = ...
detections = sv.Detections(...)
box_annotator = sv.BoxAnnotator()
labels = [
f"{classes[class_id]} {confidence:0.2f}"
for _, _, confidence, class_id, _ in detections
]
annotated_frame = box_annotator.annotate(
scene=image.copy(),
detections=detections,
labels=labels
)
```
"""
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(len(detections)):
x1, y1, x2, y2 = detections.xyxy[i].astype(int)
class_id = (
detections.class_id[i] if detections.class_id is not None else None
)
idx = class_id if class_id is not None else i
color = (
self.color.by_idx(idx)
if isinstance(self.color, ColorPalette)
else self.color
)
cv2.rectangle(
img=scene,
pt1=(x1, y1),
pt2=(x2, y2),
color=color.as_bgr(),
thickness=self.thickness,
)
if skip_label:
continue
text = (
f"{class_id}"
if (labels is None or len(detections) != len(labels))
else labels[i]
)
text_width, text_height = cv2.getTextSize(
text=text,
fontFace=font,
fontScale=self.text_scale,
thickness=self.text_thickness,
)[0]
if not self.avoid_overlap:
text_x = x1 + self.text_padding
text_y = y1 - self.text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * self.text_padding - text_height
text_background_x2 = x1 + 2 * self.text_padding + text_width
text_background_y2 = y1
# text_x = x1 - self.text_padding - text_width
# text_y = y1 + self.text_padding + text_height
# text_background_x1 = x1 - 2 * self.text_padding - text_width
# text_background_y1 = y1
# text_background_x2 = x1
# text_background_y2 = y1 + 2 * self.text_padding + text_height
else:
                (
                    text_x,
                    text_y,
                    text_background_x1,
                    text_background_y1,
                    text_background_x2,
                    text_background_y2,
                ) = get_optimal_label_pos(
                    self.text_padding, text_width, text_height,
                    x1, y1, x2, y2, detections, image_size,
                )
cv2.rectangle(
img=scene,
pt1=(text_background_x1, text_background_y1),
pt2=(text_background_x2, text_background_y2),
color=color.as_bgr(),
thickness=cv2.FILLED,
)
            # choose black or white text based on the perceived luminance of the box color
            box_color = color.as_rgb()
            luminance = 0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
cv2.putText(
img=scene,
text=text,
org=(text_x, text_y),
fontFace=font,
fontScale=self.text_scale,
# color=self.text_color.as_rgb(),
color=text_color,
thickness=self.text_thickness,
lineType=cv2.LINE_AA,
)
return scene
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2, return_max=True):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
if return_max:
return max(intersection / union, ratio1, ratio2)
else:
return intersection / union
def get_optimal_label_pos(text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size):
""" check overlap of text and background detection box, and get_optimal_label_pos,
pos: str, position of the text, must be one of 'top left', 'top right', 'outer left', 'outer right' TODO: if all are overlapping, return the last one, i.e. outer right
Threshold: default to 0.3
"""
def get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size):
is_overlap = False
for i in range(len(detections)):
detection = detections.xyxy[i].astype(int)
if IoU([text_background_x1, text_background_y1, text_background_x2, text_background_y2], detection) > 0.3:
is_overlap = True
break
# check if the text is out of the image
if text_background_x1 < 0 or text_background_x2 > image_size[0] or text_background_y1 < 0 or text_background_y2 > image_size[1]:
is_overlap = True
return is_overlap
    # candidate position: top left
text_x = x1 + text_padding
text_y = y1 - text_padding
text_background_x1 = x1
text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x1 + 2 * text_padding + text_width
text_background_y2 = y1
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
    # candidate position: outer left
text_x = x1 - text_padding - text_width
text_y = y1 + text_padding + text_height
text_background_x1 = x1 - 2 * text_padding - text_width
text_background_y1 = y1
text_background_x2 = x1
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
    # candidate position: outer right
text_x = x2 + text_padding
text_y = y1 + text_padding + text_height
text_background_x1 = x2
text_background_y1 = y1
text_background_x2 = x2 + 2 * text_padding + text_width
text_background_y2 = y1 + 2 * text_padding + text_height
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
    # candidate position: top right
text_x = x2 - text_padding - text_width
text_y = y1 - text_padding
text_background_x1 = x2 - 2 * text_padding - text_width
text_background_y1 = y1 - 2 * text_padding - text_height
text_background_x2 = x2
text_background_y2 = y1
is_overlap = get_is_overlap(detections, text_background_x1, text_background_y1, text_background_x2, text_background_y2, image_size)
if not is_overlap:
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
return text_x, text_y, text_background_x1, text_background_y1, text_background_x2, text_background_y2
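# Hypothetical usage sketch (not part of the original code; variable names are illustrative):
# given a detection box and the measured text size, get_optimal_label_pos returns the text
# origin plus the label-background rectangle for the first candidate position that neither
# overlaps another detection nor falls outside the image.
#
#   # `detections` is assumed to be an existing supervision Detections instance
#   text_w, text_h = cv2.getTextSize("3", cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)[0]
#   (text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2) = get_optimal_label_pos(
#       text_padding=10, text_width=text_w, text_height=text_h,
#       x1=100, y1=50, x2=220, y2=90,
#       detections=detections, image_size=(1920, 1080),  # image_size is (width, height)
#   )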
from util.utils import get_som_labeled_img, get_caption_model_processor, get_yolo_model, check_ocr_box
import torch
from PIL import Image
import io
import base64
from typing import Dict
class Omniparser(object):
def __init__(self, config: Dict):
self.config = config
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.som_model = get_yolo_model(model_path=config['som_model_path'])
self.caption_model_processor = get_caption_model_processor(model_name=config['caption_model_name'], model_name_or_path=config['caption_model_path'], device=device)
print('Omniparser initialized!!!')
def parse(self, image_base64: str):
image_bytes = base64.b64decode(image_base64)
image = Image.open(io.BytesIO(image_bytes))
print('image size:', image.size)
box_overlay_ratio = max(image.size) / 3200
draw_bbox_config = {
'text_scale': 0.8 * box_overlay_ratio,
'text_thickness': max(int(2 * box_overlay_ratio), 1),
'text_padding': max(int(3 * box_overlay_ratio), 1),
'thickness': max(int(3 * box_overlay_ratio), 1),
}
(text, ocr_bbox), _ = check_ocr_box(image, display_img=False, output_bb_format='xyxy', easyocr_args={'text_threshold': 0.8}, use_paddleocr=False)
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image, self.som_model, BOX_TRESHOLD=self.config['BOX_TRESHOLD'],
            output_coord_in_ratio=True, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config,
            caption_model_processor=self.caption_model_processor, ocr_text=text,
            use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128,
        )
return dino_labled_img, parsed_content_list
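# Hypothetical usage sketch (assumptions, not part of the original code): the config keys below
# mirror those read in __init__/parse ('som_model_path', 'caption_model_name',
# 'caption_model_path', 'BOX_TRESHOLD'); the concrete paths and file names are placeholders.
#
#   import base64
#   config = {
#       'som_model_path': 'weights/icon_detect/model.pt',        # placeholder path
#       'caption_model_name': 'florence2',
#       'caption_model_path': 'weights/icon_caption_florence',   # placeholder path
#       'BOX_TRESHOLD': 0.05,
#   }
#   parser = Omniparser(config)
#   with open('screenshot.png', 'rb') as f:                      # placeholder image
#       image_base64 = base64.b64encode(f.read()).decode('utf-8')
#   labeled_img_base64, parsed_content = parser.parse(image_base64)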
import re
# check whether the text is ASCII-only (treated as English)
def is_english_simple(text):
try:
text.encode(encoding='utf-8').decode('ascii')
except UnicodeDecodeError:
return False
else:
return True
# bbox -> point (str)
def bbox_2_point(bbox, dig=2):
    # bbox [left, top, right, bottom] -> center point string "(x,y)", rounded to `dig` decimals
    point = [(bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2]
    point = [f"{item:.{dig}f}" for item in point]
    point_str = "({},{})".format(point[0], point[1])
    return point_str
# bbox -> bbox (str)
def bbox_2_bbox(bbox, dig=2):
    bbox = [f"{item:.{dig}f}" for item in bbox]
    bbox_str = "({},{},{},{})".format(bbox[0], bbox[1], bbox[2], bbox[3])
    return bbox_str
# point (str) -> point
def pred_2_point(s):
    floats = re.findall(r'-?\d+\.?\d*', s)
    floats = [float(num) for num in floats]
    if len(floats) == 2:
        click_point = floats
    elif len(floats) == 4:
        click_point = [(floats[0]+floats[2])/2, (floats[1]+floats[3])/2]
    else:
        raise ValueError(f"expected 2 or 4 numbers in prediction string, got {len(floats)}: {s!r}")
    return click_point
# bbox (qwen str) -> bbox
def extract_bbox(s):
# Regular expression to find the content inside <box> and </box>
pattern = r"<box>\((\d+,\d+)\),\((\d+,\d+)\)</box>"
matches = re.findall(pattern, s)
# Convert the tuples of strings into tuples of integers
return [(int(x.split(',')[0]), int(x.split(',')[1])) for x in sum(matches, ())]
def extract_mark_id(s):
match = re.search(r'Mark: (\d+)', s)
if match:
return int(match.group(1))
return None
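# Hypothetical usage sketch (illustrative values only): these helpers round-trip between
# bounding boxes and the string formats expected by the models.
#
#   bbox_2_point([0.10, 0.20, 0.30, 0.40])      # -> "(0.20,0.30)"
#   bbox_2_bbox([0.10, 0.20, 0.30, 0.40])       # -> "(0.10,0.20,0.30,0.40)"
#   pred_2_point("(0.21, 0.84)")                # -> [0.21, 0.84]
#   pred_2_point("(12, 34, 56, 78)")            # -> [34.0, 56.0] (center of the box)
#   extract_bbox("<box>(12,34),(56,78)</box>")  # -> [(12, 34), (56, 78)]
#   extract_mark_id("Mark: 7")                  # -> 7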
import torch
from ultralytics import YOLO
from PIL import Image
import io
import base64
device = 'cuda'
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import networkx as nx
# import cv2
font_path = "agents/ui_agent/util/arial.ttf"
class MarkHelper:
def __init__(self):
self.markSize_dict = {}
self.font_dict = {}
self.min_font_size = 20 # 1 in v1
self.max_font_size = 30
self.max_font_proportion = 0.04 # 0.032 in v1
def __get_markSize(self, text, image_height, image_width, font):
im = Image.new('RGB', (image_width, image_height))
draw = ImageDraw.Draw(im)
_, _, width, height = draw.textbbox((0, 0), text=text, font=font)
return height, width
def _setup_new_font(self, image_height, image_width):
key = f"{image_height}_{image_width}"
# print(f"Setting up new font for image size: {key}")
# setup the font
fontsize = self.min_font_size
font = ImageFont.truetype(font_path, fontsize)
# font = ImageFont.load_default(size=fontsize)
while min(self.__get_markSize("555", image_height, image_width, font)) < min(self.max_font_size, self.max_font_proportion * min(image_height, image_width)):
# iterate until the text size is just larger than the criteria
fontsize += 1
font = ImageFont.truetype(font_path, fontsize)
# font = ImageFont.load_default(size=fontsize)
self.font_dict[key] = font
# setup the markSize dict
markSize_3digits = self.__get_markSize('555', image_height, image_width, font)
markSize_2digits = self.__get_markSize('55', image_height, image_width, font)
markSize_1digit = self.__get_markSize('5', image_height, image_width, font)
self.markSize_dict[key] = {
1: markSize_1digit,
2: markSize_2digits,
3: markSize_3digits
}
def get_font(self, image_height, image_width):
key = f"{image_height}_{image_width}"
if key not in self.font_dict:
self._setup_new_font(image_height, image_width)
return self.font_dict[key]
def get_mark_size(self, text_str, image_height, image_width):
"""Get the font size for the given image dimensions."""
key = f"{image_height}_{image_width}"
if key not in self.markSize_dict:
self._setup_new_font(image_height, image_width)
largest_size = self.markSize_dict[key].get(3, None)
text_h, text_w = self.markSize_dict[key].get(len(text_str), largest_size) # default to the largest size if the text is too long
return text_h, text_w
def __calculate_iou(box1, box2, return_area=False):
"""
Calculate the Intersection over Union (IoU) of two bounding boxes.
:param box1: Tuple of (y, x, h, w) for the first bounding box
:param box2: Tuple of (y, x, h, w) for the second bounding box
:return: IoU value
"""
y1, x1, h1, w1 = box1
y2, x2, h2, w2 = box2
# Calculate the intersection area
y_min = max(y1, y2)
x_min = max(x1, x2)
y_max = min(y1 + h1, y2 + h2)
x_max = min(x1 + w1, x2 + w2)
intersection_area = max(0, y_max - y_min) * max(0, x_max - x_min)
# Compute the area of both bounding boxes
box1_area = h1 * w1
box2_area = h2 * w2
# Calculate the IoU
# iou = intersection_area / box1_area + box2_area - intersection_area
iou = intersection_area / (min(box1_area, box2_area) + 0.0001)
if return_area:
return iou, intersection_area
return iou
def __calculate_nearest_corner_distance(box1, box2):
"""Calculate the distance between the nearest edge or corner of two bounding boxes."""
y1, x1, h1, w1 = box1
y2, x2, h2, w2 = box2
corners1 = np.array([
[y1, x1],
[y1, x1 + w1],
[y1 + h1, x1],
[y1 + h1, x1 + w1]
])
corners2 = np.array([
[y2, x2],
[y2, x2 + w2],
[y2 + h2, x2],
[y2 + h2, x2 + w2]
])
# Calculate pairwise distances between corners
distances = np.linalg.norm(corners1[:, np.newaxis] - corners2, axis=2)
# Find the minimum distance
min_distance = np.min(distances)
return min_distance
def _find_least_overlapping_corner(bbox, bboxes, drawn_boxes, text_size, image_size):
"""Find the corner with the least overlap with other bboxes.
Args:
bbox: (y, x, h, w) The bounding box to place the text on.
bboxes: [(y, x, h, w)] The list of bounding boxes to compare against.
drawn_boxes: [(y, x, h, w)] The list of bounding boxes that have already been drawn on.
text_size: (height, width) The size of the text to be drawn.
image_size: (height, width) The size of the image.
"""
y, x, h, w = bbox
h_text, w_text = text_size
image_height, image_width = image_size
corners = [
# top-left
(y - h_text, x),
# top-right
(y - h_text, x + w - w_text),
# right-top
(y, x + w),
# right-bottom
(y + h - h_text, x + w),
# bottom-right
(y + h, x + w - w_text),
# bottom-left
(y + h, x),
# left-bottom
(y + h - h_text, x - w_text),
# left-top
(y, x - w_text),
]
best_corner = corners[0]
max_flag = float('inf')
for corner in corners:
corner_bbox = (corner[0], corner[1], h_text, w_text)
# if the corner is out of the image, skip
if corner[0] < 0 or corner[1] < 0 or corner[0] + h_text > image_height or corner[1] + w_text > image_width:
continue
max_iou = - (image_width + image_height)
        # for the current corner, find the worst case: the largest overlap with any other bbox
for other_bbox in bboxes + drawn_boxes:
if np.array_equal(bbox, other_bbox):
continue
iou = __calculate_iou(corner_bbox, other_bbox, return_area=True)[1]
max_iou = max(max_iou, iou - 0.0001 * __calculate_nearest_corner_distance(corner_bbox, other_bbox))
        # the smaller the worst-case overlap, the better the corner; keep the relatively best one
if max_iou < max_flag:
max_flag = max_iou
best_corner = corner
return best_corner
def plot_boxes_with_marks(
image: Image.Image,
bboxes, # (y, x, h, w)
mark_helper: MarkHelper,
linewidth=2,
alpha=0,
edgecolor=None,
fn_save=None,
normalized_to_pixel=True,
add_mark=True
) -> Image.Image:
"""Plots bounding boxes on an image with marks attached to the edges of the boxes where no overlap with other boxes occurs.
Args:
image: The image to plot the bounding boxes on.
        bboxes: An array of shape (num_boxes, 4), where each row is a bounding box
            (y_top_left, x_top_left, box_height, box_width). If normalized_to_pixel is True,
            the values are floats normalized by the image size; otherwise they are pixel coordinates.
"""
# Then modify the drawing code
draw = ImageDraw.Draw(image)
# draw boxes on the image
image_width, image_height = image.size
if normalized_to_pixel:
bboxes = [(int(y * image_height), int(x * image_width), int(h * image_height), int(w * image_width)) for y, x, h, w in bboxes]
for box in bboxes:
y, x, h, w = box
draw.rectangle([x, y, x + w, y + h], outline=edgecolor, width=linewidth)
# Draw the bounding boxes with index at the least overlapping corner
drawn_boxes = []
for idx, bbox in enumerate(bboxes):
text = str(idx)
text_h, text_w = mark_helper.get_mark_size(text, image_height, image_width)
corner_y, corner_x = _find_least_overlapping_corner(
bbox, bboxes, drawn_boxes, (text_h, text_w), (image_height, image_width))
        # Define the mark box in (y, x, h, w) format
text_box = (corner_y, corner_x, text_h, text_w)
if add_mark:
# Draw the filled index box and text
draw.rectangle([corner_x, corner_y, corner_x + text_w, corner_y + text_h], # (x, y, x + w, y + h)
fill="red")
font = mark_helper.get_font(image_height, image_width)
draw.text((corner_x, corner_y), text, fill='white', font=font)
# Update the list of drawn boxes
drawn_boxes.append(np.array(text_box))
if fn_save is not None: # PIL image
image.save(fn_save)
return image
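# Hypothetical usage sketch (file names and box values are placeholders): boxes are given as
# (y, x, h, w); with normalized_to_pixel=True they are fractions of the image size.
# `markhelper` is the module-level MarkHelper instance created further below.
#
#   img = Image.open('screenshot.png')                    # placeholder image
#   boxes_yxhw = [(0.10, 0.05, 0.04, 0.20), (0.30, 0.50, 0.06, 0.10)]
#   marked = plot_boxes_with_marks(
#       img, boxes_yxhw, markhelper,
#       edgecolor='red', normalized_to_pixel=True, fn_save='marked.png',
#   )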
def plot_circles_with_marks(
image: Image.Image,
points, # (x, y)
mark_helper: MarkHelper,
linewidth=2,
edgecolor=None,
fn_save=None,
normalized_to_pixel=True,
add_mark=True
) -> Image.Image:
"""Plots bounding boxes on an image with marks attached to the edges of the boxes where no overlap with other boxes occurs.
Args:
image: The image to plot the bounding boxes on.
bboxes: A 2D int array of shape (num_boxes, 4), where each row represents a bounding box: (y_top_left, x_top_left, box_height, box_width). If normalized_to_pixel is True, the values are float and are normalized with the image size. If normalized_to_pixel is False, the values are int and are in pixel.
"""
# draw boxes on the image
image_width, image_height = image.size
if normalized_to_pixel:
bboxes = [(int(y * image_height), int(x * image_width), int(h * image_height), int(w * image_width)) for y, x, h, w in bboxes]
draw = ImageDraw.Draw(image)
for point in points:
x, y = point
draw.circle((x, y), radius=5, outline=edgecolor, width=linewidth)
if fn_save is not None: # PIL image
image.save(fn_save)
return image
markhelper = MarkHelper()
BBOX_DEDUPLICATION_IOU_PROPORTION = 0.5
BBOX_GROUPING_VERTICAL_THRESHOLD = 20
BBOX_GROUPING_HORIZONTAL_THRESHOLD = 20
BBOX_AUG_TARGET = 2.0
def _is_boxes_same_line_or_near(bbox1, bbox2, vertical_threshold, horizontal_threshold):
"""check if two boxes are in the same line or close enough to be considered together"""
y1, x1, h1, w1 = bbox1
y2, x2, h2, w2 = bbox2
# Check if the boxes are close horizontally (consider the edge case where the boxes are touching)
horizontally_close = (x1 <= x2 and x2 - x1 <= w1 + horizontal_threshold) or (x2 <= x1 and x1 - x2 <= w2 + horizontal_threshold)
# Check if the boxes are close vertically (consider the edge case where the boxes are touching)
vertically_close = (y1 <= y2 and y2 - y1 <= h1 + vertical_threshold) or (y2 <= y1 and y1 - y2 <= h2 + vertical_threshold)
# Consider the boxes to be in the same line if they are vertically close and either overlap or are close horizontally
return vertically_close and horizontally_close
def _build_adjacency_matrix(bboxes, vertical_threshold, horizontal_threshold):
"""Build the adjacency matrix based on the merging criteria."""
num_boxes = len(bboxes)
A = np.zeros((num_boxes, num_boxes), dtype=int)
for i in range(num_boxes):
for j in range(i + 1, num_boxes):
if _is_boxes_same_line_or_near(bboxes[i], bboxes[j], vertical_threshold, horizontal_threshold):
A[i, j] = 1
A[j, i] = 1 # Symmetric matrix
return A
def merge_connected_bboxes(bboxes, text_details,
vertical_threshold=BBOX_GROUPING_VERTICAL_THRESHOLD,
horizontal_threshold=BBOX_GROUPING_HORIZONTAL_THRESHOLD
):
"""Merge bboxes based on the adjacency matrix and return merged bboxes.
Args:
bboxes: A 2D array of shape (num_boxes, 4), where each row represents a bounding box: (y, x, height, width).
text_details: A list of text details for each bounding box.
vertical_threshold: The maximum vertical distance between two boxes to be considered in the same line.
horizontal_threshold: The maximum horizontal distance between two boxes to be considered close.
"""
    # nothing to merge when there are zero or one boxes
if len(bboxes) <= 1:
return bboxes, text_details
# Convert bboxes (x1, y1, x2, y2) to (y, x, height, width) format
bboxes = np.array(bboxes)
bboxes = np.array([bboxes[:, 1], bboxes[:, 0], bboxes[:, 3] - bboxes[:, 1], bboxes[:, 2] - bboxes[:, 0]]).T
# Build adjacency matrix
A = _build_adjacency_matrix(bboxes, vertical_threshold, horizontal_threshold)
# Create graph from adjacency matrix
G = nx.from_numpy_array(A)
# Find connected components
components = list(nx.connected_components(G))
# Convert bboxes to (y_min, x_min, y_max, x_max) format
corners = np.copy(bboxes)
corners_y, corners_x, corners_h, corners_w = corners[:, 0], corners[:, 1], corners[:, 2], corners[:, 3]
corners_y_max = corners_y + corners_h
corners_x_max = corners_x + corners_w
# Merge bboxes for each connected component
merged_bboxes = []
merged_text_details = []
for component in components:
indices = list(component) # e.g., [32, 33, 34, 30, 31]
indices = sorted(indices)
# merge the text details
merged_text_details.append(' '.join([text_details[i] for i in indices]))
# merge the bboxes
y_min = min(corners_y[i] for i in indices)
x_min = min(corners_x[i] for i in indices)
y_max = max(corners_y_max[i] for i in indices)
x_max = max(corners_x_max[i] for i in indices)
merged_bboxes.append((y_min, x_min, y_max - y_min, x_max - x_min)) # Convert merged_bbox back to (y, x, height, width) format
# convert (y, x, height, width) to (x1, y1, x2, y2) format without np.array
merged_bboxes = [(bbox[1], bbox[0], bbox[1] + bbox[3], bbox[0] + bbox[2]) for bbox in merged_bboxes]
return merged_bboxes, merged_text_details
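# Hypothetical usage sketch (illustrative coordinates): boxes are passed in (x1, y1, x2, y2)
# format; nearby/overlapping boxes are grouped via connected components and merged, and the
# texts of grouped boxes are joined in index order.
#
#   ocr_boxes = [(10, 10, 60, 30), (65, 12, 120, 32), (300, 200, 360, 230)]
#   ocr_texts = ['Sign', 'in', 'Cancel']
#   merged_boxes, merged_texts = merge_connected_bboxes(ocr_boxes, ocr_texts)
#   # -> two merged boxes; the adjacent texts are joined, e.g. 'Sign in'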
# from ultralytics import YOLO
import os
import io
import base64
import time
from PIL import Image, ImageDraw, ImageFont
import json
import requests
# utility function
import os
import json
import sys
import os
import cv2
import numpy as np
# %matplotlib inline
from matplotlib import pyplot as plt
import easyocr
from paddleocr import PaddleOCR
reader = easyocr.Reader(['en'])
paddle_ocr = PaddleOCR(
lang='en', # other lang also available
use_angle_cls=False,
use_gpu=False, # using cuda will conflict with pytorch in the same process
show_log=False,
max_batch_size=1024,
use_dilation=True, # improves accuracy
det_db_score_mode='slow', # improves accuracy
rec_batch_num=1024)
import time
import base64
import os
import ast
import torch
from typing import Tuple, List, Union
from torchvision.ops import box_convert
import re
from torchvision.transforms import ToPILImage
import supervision as sv
import torchvision.transforms as T
from util.box_annotator import BoxAnnotator
def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
if not device:
device = "cuda" if torch.cuda.is_available() else "cpu"
if model_name == "blip2":
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
if device == 'cpu':
model = Blip2ForConditionalGeneration.from_pretrained(
model_name_or_path, device_map=None, torch_dtype=torch.float32
)
else:
model = Blip2ForConditionalGeneration.from_pretrained(
model_name_or_path, device_map=None, torch_dtype=torch.float16
).to(device)
elif model_name == "florence2":
from transformers import AutoProcessor, AutoModelForCausalLM
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
return {'model': model.to(device), 'processor': processor}
def get_yolo_model(model_path):
from ultralytics import YOLO
# Load the model.
model = YOLO(model_path)
return model
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
to_pil = ToPILImage()
if starting_idx:
non_ocr_boxes = filtered_boxes[starting_idx:]
else:
non_ocr_boxes = filtered_boxes
croped_pil_image = []
for i, coord in enumerate(non_ocr_boxes):
try:
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
cropped_image = cv2.resize(cropped_image, (64, 64))
croped_pil_image.append(to_pil(cropped_image))
except:
continue
model, processor = caption_model_processor['model'], caption_model_processor['processor']
if not prompt:
if 'florence' in model.config.name_or_path:
prompt = "<CAPTION>"
else:
prompt = "The image shows"
generated_texts = []
device = model.device
# batch_size = 64
for i in range(0, len(croped_pil_image), batch_size):
start = time.time()
batch = croped_pil_image[i:i+batch_size]
t1 = time.time()
if model.device.type == 'cuda':
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt", do_resize=False).to(device=device, dtype=torch.float16)
else:
inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
# if 'florence' in model.config.name_or_path:
generated_ids = model.generate(input_ids=inputs["input_ids"],pixel_values=inputs["pixel_values"],max_new_tokens=20,num_beams=1, do_sample=False)
# else:
# generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1) # temperature=0.01, do_sample=True,
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
generated_text = [gen.strip() for gen in generated_text]
generated_texts.extend(generated_text)
return generated_texts
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
to_pil = ToPILImage()
if ocr_bbox:
non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
else:
non_ocr_boxes = filtered_boxes
croped_pil_image = []
for i, coord in enumerate(non_ocr_boxes):
xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
cropped_image = image_source[ymin:ymax, xmin:xmax, :]
croped_pil_image.append(to_pil(cropped_image))
model, processor = caption_model_processor['model'], caption_model_processor['processor']
device = model.device
messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
batch_size = 5 # Number of samples per batch
generated_texts = []
for i in range(0, len(croped_pil_image), batch_size):
images = croped_pil_image[i:i+batch_size]
image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
inputs ={'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
texts = [prompt] * len(images)
for i, txt in enumerate(texts):
input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
inputs['input_ids'].append(input['input_ids'])
inputs['attention_mask'].append(input['attention_mask'])
inputs['pixel_values'].append(input['pixel_values'])
inputs['image_sizes'].append(input['image_sizes'])
max_len = max([x.shape[1] for x in inputs['input_ids']])
for i, v in enumerate(inputs['input_ids']):
inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
generation_args = {
"max_new_tokens": 25,
"temperature": 0.01,
"do_sample": False,
}
generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
# # remove input tokens
generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
response = [res.strip('\n').strip() for res in response]
generated_texts.extend(response)
return generated_texts
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
assert ocr_bbox is None or isinstance(ocr_bbox, List)
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection + 1e-6
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
return max(intersection / union, ratio1, ratio2)
def is_inside(box1, box2):
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
intersection = intersection_area(box1, box2)
ratio1 = intersection / box_area(box1)
return ratio1 > 0.95
boxes = boxes.tolist()
filtered_boxes = []
if ocr_bbox:
filtered_boxes.extend(ocr_bbox)
# print('ocr_bbox!!!', ocr_bbox)
for i, box1 in enumerate(boxes):
# if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
is_valid_box = True
for j, box2 in enumerate(boxes):
# keep the smaller box
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
is_valid_box = False
break
if is_valid_box:
# add the following 2 lines to include ocr bbox
if ocr_bbox:
# only add the box if it does not overlap with any ocr bbox
if not any(IoU(box1, box3) > iou_threshold and not is_inside(box1, box3) for k, box3 in enumerate(ocr_bbox)):
filtered_boxes.append(box1)
else:
filtered_boxes.append(box1)
return torch.tensor(filtered_boxes)
def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None):
'''
ocr_bbox format: [{'type': 'text', 'bbox':[x,y], 'interactivity':False, 'content':str }, ...]
boxes format: [{'type': 'icon', 'bbox':[x,y], 'interactivity':True, 'content':None }, ...]
'''
assert ocr_bbox is None or isinstance(ocr_bbox, List)
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
def intersection_area(box1, box2):
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])
return max(0, x2 - x1) * max(0, y2 - y1)
def IoU(box1, box2):
intersection = intersection_area(box1, box2)
union = box_area(box1) + box_area(box2) - intersection + 1e-6
if box_area(box1) > 0 and box_area(box2) > 0:
ratio1 = intersection / box_area(box1)
ratio2 = intersection / box_area(box2)
else:
ratio1, ratio2 = 0, 0
return max(intersection / union, ratio1, ratio2)
def is_inside(box1, box2):
# return box1[0] >= box2[0] and box1[1] >= box2[1] and box1[2] <= box2[2] and box1[3] <= box2[3]
intersection = intersection_area(box1, box2)
ratio1 = intersection / box_area(box1)
return ratio1 > 0.80
# boxes = boxes.tolist()
filtered_boxes = []
if ocr_bbox:
filtered_boxes.extend(ocr_bbox)
# print('ocr_bbox!!!', ocr_bbox)
for i, box1_elem in enumerate(boxes):
box1 = box1_elem['bbox']
is_valid_box = True
for j, box2_elem in enumerate(boxes):
# keep the smaller box
box2 = box2_elem['bbox']
if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
is_valid_box = False
break
if is_valid_box:
if ocr_bbox:
# keep yolo boxes + prioritize ocr label
box_added = False
ocr_labels = ''
for box3_elem in ocr_bbox:
if not box_added:
box3 = box3_elem['bbox']
if is_inside(box3, box1): # ocr inside icon
# box_added = True
# delete the box3_elem from ocr_bbox
try:
# gather all ocr labels
ocr_labels += box3_elem['content'] + ' '
filtered_boxes.remove(box3_elem)
except:
continue
# break
elif is_inside(box1, box3): # icon inside ocr, don't added this icon box, no need to check other ocr bbox bc no overlap between ocr bbox, icon can only be in one ocr box
box_added = True
break
else:
continue
if not box_added:
if ocr_labels:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': ocr_labels,})
else:
filtered_boxes.append({'type': 'icon', 'bbox': box1_elem['bbox'], 'interactivity': True, 'content': None, })
else:
filtered_boxes.append(box1)
return filtered_boxes # torch.tensor(filtered_boxes)
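# Hypothetical usage sketch (illustrative boxes, xyxy in pixels): OCR boxes keep their text,
# detected icon boxes that contain an OCR box absorb its text, and icon boxes that fall
# inside an OCR box are dropped.
#
#   icon_boxes = [{'type': 'icon', 'bbox': [100, 100, 300, 200], 'interactivity': True, 'content': None}]
#   ocr_boxes = [{'type': 'text', 'bbox': [120, 130, 280, 170], 'interactivity': False, 'content': 'Submit'}]
#   merged = remove_overlap_new(icon_boxes, iou_threshold=0.7, ocr_bbox=ocr_boxes)
#   # -> one icon element whose 'content' contains the OCR text 'Submit'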
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
"""
This function annotates an image with bounding boxes and labels.
Parameters:
image_source (np.ndarray): The source image to be annotated.
        boxes (torch.Tensor): A tensor of bounding box coordinates in cxcywh format,
            normalized to [0, 1]; they are scaled to pixel coordinates inside this function.
logits (torch.Tensor): A tensor containing confidence scores for each bounding box.
phrases (List[str]): A list of labels for each bounding box.
text_scale (float): The scale of the text to be displayed. 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
Returns:
np.ndarray: The annotated image.
"""
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
detections = sv.Detections(xyxy=xyxy)
labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding,text_thickness=text_thickness,thickness=thickness) # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
annotated_frame = image_source.copy()
annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w,h))
label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
return annotated_frame, label_coordinates
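# Hypothetical usage sketch (placeholder image path; values are illustrative): boxes are
# expected in cxcywh format normalized to [0, 1]; annotate() scales them to pixels and draws
# index labels via BoxAnnotator.
#
#   image_rgb = np.asarray(Image.open('screenshot.png').convert('RGB'))
#   boxes_cxcywh = torch.tensor([[0.50, 0.50, 0.20, 0.10]])
#   frame, coords = annotate(
#       image_source=image_rgb, boxes=boxes_cxcywh, logits=torch.tensor([0.9]),
#       phrases=['0'], text_scale=0.4,
#   )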
def predict(model, image, caption, box_threshold, text_threshold):
""" Use huggingface model to replace the original model
"""
model, processor = model['model'], model['processor']
device = model.device
inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
results = processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold, # 0.4,
text_threshold=text_threshold, # 0.3,
target_sizes=[image.size[::-1]]
)[0]
boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
return boxes, logits, phrases
def predict_yolo(model, image, box_threshold, imgsz, scale_img, iou_threshold=0.7):
""" Use huggingface model to replace the original model
"""
# model = model['model']
if scale_img:
result = model.predict(
source=image,
conf=box_threshold,
imgsz=imgsz,
iou=iou_threshold, # default 0.7
)
else:
result = model.predict(
source=image,
conf=box_threshold,
iou=iou_threshold, # default 0.7
)
    boxes = result[0].boxes.xyxy  # in pixel space
conf = result[0].boxes.conf
phrases = [str(i) for i in range(len(boxes))]
return boxes, conf, phrases
def int_box_area(box, w, h):
x1, y1, x2, y2 = box
int_box = [int(x1*w), int(y1*h), int(x2*w), int(y2*h)]
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
return area
def get_som_labeled_img(
    image_source: Union[str, Image.Image],
    model=None,
    BOX_TRESHOLD=0.01,
    output_coord_in_ratio=False,
    ocr_bbox=None,
    text_scale=0.4,
    text_padding=5,
    draw_bbox_config=None,
    caption_model_processor=None,
    ocr_text=[],
    use_local_semantics=True,
    iou_threshold=0.9,
    prompt=None,
    scale_img=False,
    imgsz=None,
    batch_size=64,
):
"""Process either an image path or Image object
Args:
image_source: Either a file path (str) or PIL Image object
...
"""
if isinstance(image_source, str):
image_source = Image.open(image_source).convert("RGB")
w, h = image_source.size
if not imgsz:
imgsz = (h, w)
# print('image size:', w, h)
xyxy, logits, phrases = predict_yolo(model=model, image=image_source, box_threshold=BOX_TRESHOLD, imgsz=imgsz, scale_img=scale_img, iou_threshold=0.1)
xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
image_source = np.asarray(image_source)
phrases = [str(i) for i in range(len(phrases))]
# annotate the image with labels
if ocr_bbox:
ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
ocr_bbox=ocr_bbox.tolist()
else:
print('no ocr bbox!!!')
ocr_bbox = None
    ocr_bbox_elem = [
        {'type': 'text', 'bbox': box, 'interactivity': False, 'content': txt}
        for box, txt in zip(ocr_bbox or [], ocr_text) if int_box_area(box, w, h) > 0
    ]
    xyxy_elem = [
        {'type': 'icon', 'bbox': box, 'interactivity': True, 'content': None}
        for box in xyxy.tolist() if int_box_area(box, w, h) > 0
    ]
filtered_boxes = remove_overlap_new(boxes=xyxy_elem, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox_elem)
# sort the filtered_boxes so that the one with 'content': None is at the end, and get the index of the first 'content': None
filtered_boxes_elem = sorted(filtered_boxes, key=lambda x: x['content'] is None)
# get the index of the first 'content': None
starting_idx = next((i for i, box in enumerate(filtered_boxes_elem) if box['content'] is None), -1)
filtered_boxes = torch.tensor([box['bbox'] for box in filtered_boxes_elem])
print('len(filtered_boxes):', len(filtered_boxes), starting_idx)
# get parsed icon local semantics
time1 = time.time()
if use_local_semantics:
caption_model = caption_model_processor['model']
if 'phi3_v' in caption_model.config.model_type:
parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
else:
parsed_content_icon = get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=prompt,batch_size=batch_size)
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
icon_start = len(ocr_text)
parsed_content_icon_ls = []
# fill the filtered_boxes_elem None content with parsed_content_icon in order
for i, box in enumerate(filtered_boxes_elem):
if box['content'] is None:
box['content'] = parsed_content_icon.pop(0)
for i, txt in enumerate(parsed_content_icon):
parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
parsed_content_merged = ocr_text + parsed_content_icon_ls
else:
ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
parsed_content_merged = ocr_text
print('time to get parsed content:', time.time()-time1)
filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
phrases = [i for i in range(len(filtered_boxes))]
# draw boxes
if draw_bbox_config:
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
else:
annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
pil_img = Image.fromarray(annotated_frame)
buffered = io.BytesIO()
pil_img.save(buffered, format="PNG")
encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
if output_coord_in_ratio:
label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
return encoded_image, label_coordinates, filtered_boxes_elem
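# Hypothetical end-to-end sketch (model paths, file names, and thresholds are placeholders):
# run OCR first, then label the screenshot with both OCR text boxes and YOLO-detected icon boxes.
#
#   som_model = get_yolo_model('weights/icon_detect/model.pt')              # placeholder path
#   caption_model_processor = get_caption_model_processor(
#       'florence2', model_name_or_path='weights/icon_caption_florence')    # placeholder path
#   (ocr_text, ocr_bbox), _ = check_ocr_box(
#       'screenshot.png', display_img=False, output_bb_format='xyxy')
#   encoded_image, label_coordinates, parsed_content = get_som_labeled_img(
#       'screenshot.png', model=som_model, BOX_TRESHOLD=0.05,
#       ocr_bbox=ocr_bbox, ocr_text=ocr_text,
#       caption_model_processor=caption_model_processor, iou_threshold=0.7,
#   )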
def get_xywh(input):
x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
x, y, w, h = int(x), int(y), int(w), int(h)
return x, y, w, h
def get_xyxy(input):
x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
x, y, xp, yp = int(x), int(y), int(xp), int(yp)
return x, y, xp, yp
def get_xywh_yolo(input):
x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
x, y, w, h = int(x), int(y), int(w), int(h)
return x, y, w, h
def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
if isinstance(image_source, str):
image_source = Image.open(image_source)
if image_source.mode == 'RGBA':
# Convert RGBA to RGB to avoid alpha channel issues
image_source = image_source.convert('RGB')
image_np = np.array(image_source)
w, h = image_source.size
if use_paddleocr:
if easyocr_args is None:
text_threshold = 0.5
else:
text_threshold = easyocr_args['text_threshold']
result = paddle_ocr.ocr(image_np, cls=False)[0]
coord = [item[0] for item in result if item[1][1] > text_threshold]
text = [item[1][0] for item in result if item[1][1] > text_threshold]
else: # EasyOCR
if easyocr_args is None:
easyocr_args = {}
result = reader.readtext(image_np, **easyocr_args)
coord = [item[0] for item in result]
text = [item[1] for item in result]
if display_img:
opencv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
bb = []
for item in coord:
x, y, a, b = get_xywh(item)
bb.append((x, y, a, b))
cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
# matplotlib expects RGB
plt.imshow(cv2.cvtColor(opencv_img, cv2.COLOR_BGR2RGB))
else:
if output_bb_format == 'xywh':
bb = [get_xywh(item) for item in coord]
elif output_bb_format == 'xyxy':
bb = [get_xyxy(item) for item in coord]
return (text, bb), goal_filtering
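# Hypothetical usage sketch (placeholder file name): with use_paddleocr=False the module-level
# EasyOCR reader is used and easyocr_args are forwarded to reader.readtext; with
# use_paddleocr=True only 'text_threshold' is read from easyocr_args.
#
#   (texts, boxes), _ = check_ocr_box(
#       'screenshot.png', display_img=False, output_bb_format='xyxy',
#       easyocr_args={'paragraph': False, 'text_threshold': 0.8}, use_paddleocr=False,
#   )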
# datasets
from .epic import epic
from .ego4d import ego4d
from .openx import openx
from .openx_magma import openx_magma
from .magma import magma
from .llava import llava
from .seeclick import seeclick
# (joint) datasets
from .dataset import build_joint_dataset
# data collators
from .data_collator import DataCollatorForSupervisedDataset
from .data_collator import DataCollatorForHFDataset