Unverified commit 55e9bb9b, authored by Xiaomeng Zhao, committed by GitHub
Browse files

Merge pull request #1284 from IMSUVEN/dev

parents 8ccfff6f be010394
...@@ -34,25 +34,62 @@ class BatchAnalyze: ...@@ -34,25 +34,62 @@ class BatchAnalyze:
self.batch_ratio = batch_ratio self.batch_ratio = batch_ratio
def __call__(self, images: list) -> list: def __call__(self, images: list) -> list:
images_layout_res = []
layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3: if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3 # layoutlmv3
images_layout_res = []
for image in images: for image in images:
layout_res = self.model.layout_model(image, ignore_catids=[]) layout_res = self.model.layout_model(image, ignore_catids=[])
images_layout_res.append(layout_res) images_layout_res.append(layout_res)
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo # doclayout_yolo
images_layout_res = self.model.layout_model.batch_predict( layout_images = []
images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
width, height = pil_img.size
if height > width:
input_res = {"poly": [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
)
layout_images.append(new_image)
modified_images.append([image_index, useful_list])
else:
layout_images.append(pil_img)
images_layout_res += self.model.layout_model.batch_predict(
layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
)
for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res["poly"])):
if i % 2 == 0:
res["poly"][i] = (
res["poly"][i] - useful_list[0] + useful_list[2]
)
else:
res["poly"][i] = (
res["poly"][i] - useful_list[1] + useful_list[3]
)
logger.info(
f"layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}"
) )
if self.model.apply_formula: if self.model.apply_formula:
# 公式检测 # 公式检测
mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict( images_mfd_res = self.model.mfd_model.batch_predict(
images, self.batch_ratio * MFD_BASE_BATCH_SIZE images, self.batch_ratio * MFD_BASE_BATCH_SIZE
) )
logger.info(
f"mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}"
)
# 公式识别 # 公式识别
mfr_start_time = time.time()
images_formula_list = self.model.mfr_model.batch_predict( images_formula_list = self.model.mfr_model.batch_predict(
images_mfd_res, images_mfd_res,
images, images,
...@@ -60,10 +97,17 @@ class BatchAnalyze: ...@@ -60,10 +97,17 @@ class BatchAnalyze:
) )
for image_index in range(len(images)): for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index] images_layout_res[image_index] += images_formula_list[image_index]
logger.info(
f"mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}"
)
# 清理显存 # 清理显存
clean_vram(self.model.device, vram_threshold=8) clean_vram(self.model.device, vram_threshold=8)
ocr_time = 0
ocr_count = 0
table_time = 0
table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)): for index in range(len(images)):
layout_res = images_layout_res[index] layout_res = images_layout_res[index]
...@@ -99,12 +143,8 @@ class BatchAnalyze: ...@@ -99,12 +143,8 @@ class BatchAnalyze:
if ocr_res: if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list) ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list) layout_res.extend(ocr_result_list)
ocr_time += time.time() - ocr_start
ocr_cost = round(time.time() - ocr_start, 2) ocr_count += len(ocr_res_list)
if self.model.apply_ocr:
logger.info(f"ocr time: {ocr_cost}")
else:
logger.info(f"det time: {ocr_cost}")
# 表格识别 table recognition # 表格识别 table recognition
if self.model.apply_table: if self.model.apply_table:
...@@ -146,7 +186,17 @@ class BatchAnalyze: ...@@ -146,7 +186,17 @@ class BatchAnalyze:
logger.warning( logger.warning(
"table recognition processing fails, not get html return" "table recognition processing fails, not get html return"
) )
logger.info(f"table time: {round(time.time() - table_start, 2)}") table_time += time.time() - table_start
table_count += len(table_res_list)
if self.model.apply_ocr:
logger.info(f"ocr time: {round(ocr_time, 2)}, image num: {ocr_count}")
else:
logger.info(f"det time: {round(ocr_time, 2)}, image num: {ocr_count}")
if self.model.apply_table:
logger.info(f"table time: {round(table_time, 2)}, image num: {table_count}")
return images_layout_res
def doc_batch_analyze( def doc_batch_analyze(
...@@ -223,6 +273,8 @@ def doc_batch_analyze( ...@@ -223,6 +273,8 @@ def doc_batch_analyze(
model_json.append(page_dict) model_json.append(page_dict)
# TODO: clean memory when gpu memory is not enough # TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory() clean_memory()
logger.info(f"clean memory time: {round(time.time() - clean_memory_start_time, 2)}")
return InferenceResult(model_json, dataset) return InferenceResult(model_json, dataset)
...@@ -28,14 +28,17 @@ class DocLayoutYOLOModel(object): ...@@ -28,14 +28,17 @@ class DocLayoutYOLOModel(object):
def batch_predict(self, images: list, batch_size: int) -> list: def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = [] images_layout_res = []
for index in range(0, len(images), batch_size): for index in range(0, len(images), batch_size):
doclayout_yolo_res = self.model.predict( doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size], images[index : index + batch_size],
imgsz=1024, imgsz=1024,
conf=0.25, conf=0.25,
iou=0.45, iou=0.45,
verbose=True, verbose=True,
device=self.device, device=self.device,
).cpu() )
]
for image_res in doclayout_yolo_res: for image_res in doclayout_yolo_res:
layout_res = [] layout_res = []
for xyxy, conf, cla in zip( for xyxy, conf, cla in zip(
......
...@@ -15,14 +15,17 @@ class YOLOv8MFDModel(object): ...@@ -15,14 +15,17 @@ class YOLOv8MFDModel(object):
def batch_predict(self, images: list, batch_size: int) -> list:
    """Run formula detection over ``images`` in fixed-size batches.

    Args:
        images: Input images; consecutive slices of this list are passed
            unchanged to ``self.mfd_model.predict``.
        batch_size: Maximum number of images handed to a single
            ``predict`` call. Values below 1 are treated as 1.

    Returns:
        A flat list with one entry per input image, in input order. Each
        entry is the value returned by calling ``.cpu()`` on the
        corresponding per-image result from ``predict``.
    """
    # range() raises ValueError on a zero step and yields nothing on a
    # negative one, which would either crash or silently drop every image.
    batch_size = max(1, batch_size)

    images_mfd_res = []
    for index in range(0, len(images), batch_size):
        batch_res = self.mfd_model.predict(
            images[index : index + batch_size],
            imgsz=1888,
            conf=0.25,
            iou=0.45,
            verbose=True,
            device=self.device,
        )
        # Convert each per-image result with .cpu() before accumulating.
        images_mfd_res.extend(image_res.cpu() for image_res in batch_res)
    return images_mfd_res
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment