Unverified commit 037a3ae6, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #2763 from herryqg/master

encapsulate prediction parsing logic in DocLayoutYOLOModel
parents af7dee49 4c52a05b
Removed (old DocLayoutYOLOModel):

from doclayout_yolo import YOLOv10
from tqdm import tqdm


class DocLayoutYOLOModel(object):
    def __init__(self, weight, device):
        self.model = YOLOv10(weight)
        self.device = device

    def predict(self, image):
        layout_res = []
        doclayout_yolo_res = self.model.predict(
            image,
            imgsz=1280,
            conf=0.10,
            iou=0.45,
            verbose=False, device=self.device
        )[0]
        for xyxy, conf, cla in zip(
            doclayout_yolo_res.boxes.xyxy.cpu(),
            doclayout_yolo_res.boxes.conf.cpu(),
            doclayout_yolo_res.boxes.cls.cpu(),
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            }
            layout_res.append(new_item)
        return layout_res

    def batch_predict(self, images: list, batch_size: int) -> list:
        images_layout_res = []
        # for index in range(0, len(images), batch_size):
        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
            doclayout_yolo_res = [
                image_res.cpu()
                for image_res in self.model.predict(
                    images[index : index + batch_size],
                    imgsz=1280,
                    conf=0.10,
                    iou=0.45,
                    verbose=False,
                    device=self.device,
                )
            ]
            for image_res in doclayout_yolo_res:
                layout_res = []
                for xyxy, conf, cla in zip(
                    image_res.boxes.xyxy,
                    image_res.boxes.conf,
                    image_res.boxes.cls,
                ):
                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                    new_item = {
                        "category_id": int(cla.item()),
                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                        "score": round(float(conf.item()), 3),
                    }
                    layout_res.append(new_item)
                images_layout_res.append(layout_res)
        return images_layout_res

Added (new DocLayoutYOLOModel):

from typing import List, Dict, Union
from doclayout_yolo import YOLOv10
from tqdm import tqdm
import numpy as np
from PIL import Image


class DocLayoutYOLOModel:
    def __init__(
        self,
        weight: str,
        device: str = "cuda",
        imgsz: int = 1280,
        conf: float = 0.1,
        iou: float = 0.45,
    ):
        self.model = YOLOv10(weight).to(device)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def _parse_prediction(self, prediction) -> List[Dict]:
        layout_res = []
        # Fault tolerance: return an empty result if the prediction carries no boxes
        if not hasattr(prediction, "boxes") or prediction.boxes is None:
            return layout_res
        for xyxy, conf, cls in zip(
            prediction.boxes.xyxy.cpu(),
            prediction.boxes.conf.cpu(),
            prediction.boxes.cls.cpu(),
        ):
            coords = list(map(int, xyxy.tolist()))
            xmin, ymin, xmax, ymax = coords
            layout_res.append({
                "category_id": int(cls.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            })
        return layout_res

    def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
        prediction = self.model.predict(
            image,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
            verbose=False
        )[0]
        return self._parse_prediction(prediction)

    def batch_predict(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        batch_size: int = 4
    ) -> List[List[Dict]]:
        results = []
        for idx in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
            batch = images[idx: idx + batch_size]
            predictions = self.model.predict(
                batch,
                imgsz=self.imgsz,
                conf=self.conf,
                iou=self.iou,
                verbose=False,
            )
            for pred in predictions:
                results.append(self._parse_prediction(pred))
        return results
\ No newline at end of file
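A minimal usage sketch of the refactored layout model. The weight path and page-image filenames below are placeholders for illustration only; they are not files referenced by this PR.

from PIL import Image

# Placeholder weight path and page images (assumptions, not part of the PR)
layout_model = DocLayoutYOLOModel(weight="doclayout_yolo.pt", device="cuda")

# Single page: returns a list of {"category_id", "poly", "score"} dicts,
# one per detected layout region
single_page_res = layout_model.predict(Image.open("page_0.png"))

# Multiple pages: returns one such list per input image
pages = [Image.open(f"page_{i}.png") for i in range(8)]
batched_res = layout_model.batch_predict(pages, batch_size=4)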
Removed (old YOLOv8MFDModel):

from tqdm import tqdm
from ultralytics import YOLO


class YOLOv8MFDModel(object):
    def __init__(self, weight, device="cpu"):
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        mfd_res = self.mfd_model.predict(
            image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
        )[0]
        return mfd_res

    def batch_predict(self, images: list, batch_size: int) -> list:
        images_mfd_res = []
        # for index in range(0, len(images), batch_size):
        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
            mfd_res = [
                image_res.cpu()
                for image_res in self.mfd_model.predict(
                    images[index : index + batch_size],
                    imgsz=1888,
                    conf=0.25,
                    iou=0.45,
                    verbose=False,
                    device=self.device,
                )
            ]
            for image_res in mfd_res:
                images_mfd_res.append(image_res)
        return images_mfd_res

Added (new YOLOv8MFDModel):

from typing import List, Union
from tqdm import tqdm
from ultralytics import YOLO
import numpy as np
from PIL import Image


class YOLOv8MFDModel:
    def __init__(
        self,
        weight: str,
        device: str = "cpu",
        imgsz: int = 1888,
        conf: float = 0.25,
        iou: float = 0.45,
    ):
        self.model = YOLO(weight).to(device)
        self.device = device
        self.imgsz = imgsz
        self.conf = conf
        self.iou = iou

    def _run_predict(
        self,
        inputs: Union[np.ndarray, Image.Image, List],
        is_batch: bool = False
    ) -> List:
        preds = self.model.predict(
            inputs,
            imgsz=self.imgsz,
            conf=self.conf,
            iou=self.iou,
            verbose=False,
            device=self.device
        )
        return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()

    def predict(self, image: Union[np.ndarray, Image.Image]):
        return self._run_predict(image)

    def batch_predict(
        self,
        images: List[Union[np.ndarray, Image.Image]],
        batch_size: int = 4
    ) -> List:
        results = []
        for idx in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
            batch = images[idx: idx + batch_size]
            batch_preds = self._run_predict(batch, is_batch=True)
            results.extend(batch_preds)
        return results
\ No newline at end of file
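For symmetry, a usage sketch of the refactored formula-detection (MFD) model; the weight path and image filenames are again placeholders, not files referenced by this PR.

from PIL import Image

# Placeholder weight path and page images (assumptions, not part of the PR)
mfd_model = YOLOv8MFDModel(weight="yolov8_mfd.pt", device="cuda")

# Single page: one ultralytics Results object, moved to CPU by _run_predict
single_res = mfd_model.predict(Image.open("page_0.png"))

# Multiple pages: a flat list with one Results object per input image
batch_res = mfd_model.batch_predict(
    [Image.open(f"page_{i}.png") for i in range(8)],
    batch_size=4,
)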
@@ -343,6 +343,14 @@
"created_at": "2025-06-18T11:27:23Z",
"repoId": 765083837,
"pullRequestNo": 2727
},
{
"name": "QIN2DIM",
"id": 62018067,
"comment_id": 2992279796,
"created_at": "2025-06-20T17:04:59Z",
"repoId": 765083837,
"pullRequestNo": 2758
}
]
}
\ No newline at end of file