Unverified Commit f6bc4f70 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2003 from icecraft/feat/batch_analyze_with_ocr_and_lang

feat: batch inference with ocr and lang flag
parents 2c8470b0 bbba2a12
...@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16 ...@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16
class BatchAnalyze: class BatchAnalyze:
def __init__(self, model: CustomPEKModel, batch_ratio: int): def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
self.model = model self.model_manager = model_manager
self.batch_ratio = batch_ratio self.batch_ratio = batch_ratio
self.show_log = show_log
def __call__(self, images: list) -> list: self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def __call__(self, images_with_extra_info: list) -> list:
if len(images_with_extra_info) == 0:
return []
images_layout_res = [] images_layout_res = []
layout_start_time = time.time() layout_start_time = time.time()
_, fst_ocr, fst_lang = images_with_extra_info[0]
self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
images = [image for image, _, _ in images_with_extra_info]
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3: if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3 # layoutlmv3
for image in images: for image in images:
...@@ -79,6 +91,8 @@ class BatchAnalyze: ...@@ -79,6 +91,8 @@ class BatchAnalyze:
table_count = 0 table_count = 0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)): for index in range(len(images)):
_, ocr_enable, _lang = images_with_extra_info[index]
self.model = self.model_manager.get_model(ocr_enable, self.show_log, _lang, self.layout_model, self.formula_enable, self.table_enable)
layout_res = images_layout_res[index] layout_res = images_layout_res[index]
np_array_img = images[index] np_array_img = images[index]
...@@ -99,7 +113,7 @@ class BatchAnalyze: ...@@ -99,7 +113,7 @@ class BatchAnalyze:
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.model.apply_ocr: if ocr_enable:
ocr_res = self.model.ocr_model.ocr( ocr_res = self.model.ocr_model.ocr(
new_image, mfd_res=adjusted_mfdetrec_res new_image, mfd_res=adjusted_mfdetrec_res
)[0] )[0]
...@@ -159,9 +173,7 @@ class BatchAnalyze: ...@@ -159,9 +173,7 @@ class BatchAnalyze:
table_count += len(table_res_list) table_count += len(table_res_list)
if self.model.apply_ocr: if self.model.apply_ocr:
logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}') logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
else:
logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
if self.model.apply_table: if self.model.apply_table:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}') logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
......
...@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 ...@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from loguru import logger from loguru import logger
from magic_pdf.model.sub_modules.model_utils import get_vram from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.config.enums import SupportedPdfParseMethod
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
...@@ -150,12 +150,13 @@ def doc_analyze( ...@@ -150,12 +150,13 @@ def doc_analyze(
img_dict = page_data.get_image() img_dict = page_data.get_image()
images.append(img_dict['img']) images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height'])) page_wh_list.append((img_dict['width'], img_dict['height']))
images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
if len(images) >= MIN_BATCH_INFERENCE_SIZE: if len(images) >= MIN_BATCH_INFERENCE_SIZE:
batch_size = MIN_BATCH_INFERENCE_SIZE batch_size = MIN_BATCH_INFERENCE_SIZE
batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)] batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
else: else:
batch_images = [images] batch_images = [images_with_extra_info]
results = [] results = []
for sn, batch_image in enumerate(batch_images): for sn, batch_image in enumerate(batch_images):
...@@ -181,7 +182,7 @@ def doc_analyze( ...@@ -181,7 +182,7 @@ def doc_analyze(
def batch_doc_analyze( def batch_doc_analyze(
datasets: list[Dataset], datasets: list[Dataset],
ocr: bool = False, parse_method: str,
show_log: bool = False, show_log: bool = False,
lang=None, lang=None,
layout_model=None, layout_model=None,
...@@ -192,47 +193,31 @@ def batch_doc_analyze( ...@@ -192,47 +193,31 @@ def batch_doc_analyze(
batch_size = MIN_BATCH_INFERENCE_SIZE batch_size = MIN_BATCH_INFERENCE_SIZE
images = [] images = []
page_wh_list = [] page_wh_list = []
lang_list = []
lang_s = set() images_with_extra_info = []
for dataset in datasets: for dataset in datasets:
for index in range(len(dataset)): for index in range(len(dataset)):
if lang is None or lang == 'auto': if lang is None or lang == 'auto':
lang_list.append(dataset._lang) _lang = dataset._lang
else: else:
lang_list.append(lang) _lang = lang
lang_s.add(lang_list[-1])
page_data = dataset.get_page(index) page_data = dataset.get_page(index)
img_dict = page_data.get_image() img_dict = page_data.get_image()
images.append(img_dict['img']) images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height'])) page_wh_list.append((img_dict['width'], img_dict['height']))
if parse_method == 'auto':
images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
else:
images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
batch_images = [] batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
img_idx_list = [] results = []
for t_lang in lang_s: for sn, batch_image in enumerate(batch_images):
tmp_img_idx_list = [] _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
for i, _lang in enumerate(lang_list): results.extend(result)
if _lang == t_lang:
tmp_img_idx_list.append(i)
img_idx_list.extend(tmp_img_idx_list)
if batch_size >= len(tmp_img_idx_list):
batch_images.append((t_lang, [images[j] for j in tmp_img_idx_list]))
else:
slices = [tmp_img_idx_list[k:k+batch_size] for k in range(0, len(tmp_img_idx_list), batch_size)]
for arr in slices:
batch_images.append((t_lang, [images[j] for j in arr]))
unorder_results = []
for sn, (_lang, batch_image) in enumerate(batch_images):
_, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, _lang, layout_model, formula_enable, table_enable)
unorder_results.extend(result)
results = [None] * len(img_idx_list)
for i, idx in enumerate(img_idx_list):
results[idx] = unorder_results[i]
infer_results = [] infer_results = []
from magic_pdf.operators.models import InferenceResult from magic_pdf.operators.models import InferenceResult
for index in range(len(datasets)): for index in range(len(datasets)):
dataset = datasets[index] dataset = datasets[index]
...@@ -248,9 +233,9 @@ def batch_doc_analyze( ...@@ -248,9 +233,9 @@ def batch_doc_analyze(
def may_batch_image_analyze( def may_batch_image_analyze(
images: list[np.ndarray], images_with_extra_info: list[(np.ndarray, bool, str)],
idx: int, idx: int,
ocr: bool = False, ocr: bool,
show_log: bool = False, show_log: bool = False,
lang=None, lang=None,
layout_model=None, layout_model=None,
...@@ -267,6 +252,7 @@ def may_batch_image_analyze( ...@@ -267,6 +252,7 @@ def may_batch_image_analyze(
ocr, show_log, lang, layout_model, formula_enable, table_enable ocr, show_log, lang, layout_model, formula_enable, table_enable
) )
images = [image for image, _, _ in images_with_extra_info]
batch_analyze = False batch_analyze = False
batch_ratio = 1 batch_ratio = 1
device = get_device() device = get_device()
...@@ -306,8 +292,8 @@ def may_batch_image_analyze( ...@@ -306,8 +292,8 @@ def may_batch_image_analyze(
images.append(img_dict['img']) images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height'])) page_wh_list.append((img_dict['width'], img_dict['height']))
""" """
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio) batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
results = batch_model(images) results = batch_model(images_with_extra_info)
""" """
for index in range(len(dataset)): for index in range(len(dataset)):
if start_page_id <= index <= end_page_id: if start_page_id <= index <= end_page_id:
......
...@@ -314,21 +314,10 @@ def batch_do_parse( ...@@ -314,21 +314,10 @@ def batch_do_parse(
dss.append(PymuDocDataset(v, lang=lang)) dss.append(PymuDocDataset(v, lang=lang))
else: else:
dss.append(v) dss.append(v)
dss_with_fn = list(zip(dss, pdf_file_names))
if parse_method == 'auto': infer_results = batch_doc_analyze(dss, parse_method, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
dss_typed_txt = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.TXT]
dss_typed_ocr = [(i, x) for i, x in enumerate(dss_with_fn) if x[0].classify() == SupportedPdfParseMethod.OCR]
infer_results = [None] * len(dss_with_fn)
infer_results_txt = batch_doc_analyze([x[1][0] for x in dss_typed_txt], lang=lang, ocr=False, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
infer_results_ocr = batch_doc_analyze([x[1][0] for x in dss_typed_ocr], lang=lang, ocr=True, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
for i, infer_res in enumerate(infer_results_txt):
infer_results[dss_typed_txt[i][0]] = infer_res
for i, infer_res in enumerate(infer_results_ocr):
infer_results[dss_typed_ocr[i][0]] = infer_res
else:
infer_results = batch_doc_analyze(dss, lang=lang, ocr=parse_method == 'ocr', layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
for idx, infer_result in enumerate(infer_results): for idx, infer_result in enumerate(infer_results):
_do_parse(output_dir, dss_with_fn[idx][1], dss_with_fn[idx][0], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang) _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto']) parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment