"tests/pytorch/vscode:/vscode.git/clone" did not exist on "0a65559276b7ae27df7b0a43dd2579a0e1ce9be4"
Commit 4fd1e41e authored by Suven's avatar Suven
Browse files

fix: batch methods in DocLayoutYOLO and YOLOv8 models

parent 8ccfff6f
...@@ -34,15 +34,15 @@ class BatchAnalyze: ...@@ -34,15 +34,15 @@ class BatchAnalyze:
self.batch_ratio = batch_ratio self.batch_ratio = batch_ratio
def __call__(self, images: list) -> list: def __call__(self, images: list) -> list:
images_layout_res = []
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3: if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3 # layoutlmv3
images_layout_res = []
for image in images: for image in images:
layout_res = self.model.layout_model(image, ignore_catids=[]) layout_res = self.model.layout_model(image, ignore_catids=[])
images_layout_res.append(layout_res) images_layout_res.append(layout_res)
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo # doclayout_yolo
images_layout_res = self.model.layout_model.batch_predict( images_layout_res += self.model.layout_model.batch_predict(
images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
) )
...@@ -148,6 +148,8 @@ class BatchAnalyze: ...@@ -148,6 +148,8 @@ class BatchAnalyze:
) )
logger.info(f"table time: {round(time.time() - table_start, 2)}") logger.info(f"table time: {round(time.time() - table_start, 2)}")
return images_layout_res
def doc_batch_analyze( def doc_batch_analyze(
dataset: Dataset, dataset: Dataset,
......
...@@ -28,14 +28,17 @@ class DocLayoutYOLOModel(object): ...@@ -28,14 +28,17 @@ class DocLayoutYOLOModel(object):
def batch_predict(self, images: list, batch_size: int) -> list: def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = [] images_layout_res = []
for index in range(0, len(images), batch_size): for index in range(0, len(images), batch_size):
doclayout_yolo_res = self.model.predict( doclayout_yolo_res = [
images[index : index + batch_size], image_res.cpu()
imgsz=1024, for image_res in self.model.predict(
conf=0.25, images[index : index + batch_size],
iou=0.45, imgsz=1024,
verbose=True, conf=0.25,
device=self.device, iou=0.45,
).cpu() verbose=True,
device=self.device,
)
]
for image_res in doclayout_yolo_res: for image_res in doclayout_yolo_res:
layout_res = [] layout_res = []
for xyxy, conf, cla in zip( for xyxy, conf, cla in zip(
......
...@@ -15,14 +15,17 @@ class YOLOv8MFDModel(object): ...@@ -15,14 +15,17 @@ class YOLOv8MFDModel(object):
def batch_predict(self, images: list, batch_size: int) -> list: def batch_predict(self, images: list, batch_size: int) -> list:
images_mfd_res = [] images_mfd_res = []
for index in range(0, len(images), batch_size): for index in range(0, len(images), batch_size):
mfd_res = self.mfd_model.predict( mfd_res = [
images[index : index + batch_size], image_res.cpu()
imgsz=1888, for image_res in self.mfd_model.predict(
conf=0.25, images[index : index + batch_size],
iou=0.45, imgsz=1888,
verbose=True, conf=0.25,
device=self.device, iou=0.45,
).cpu() verbose=True,
device=self.device,
)
]
for image_res in mfd_res: for image_res in mfd_res:
images_mfd_res.append(image_res) images_mfd_res.append(image_res)
return images_mfd_res return images_mfd_res
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment