Unverified Commit a9b37b71 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge branch 'opendatalab:dev' into dev

parents 3c69c569 ec566d22
......@@ -150,7 +150,10 @@ def doc_analyze(
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
if lang is None or lang == 'auto':
images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
else:
images_with_extra_info = [(images[index], ocr, lang) for index in range(len(dataset))]
if len(images) >= MIN_BATCH_INFERENCE_SIZE:
batch_size = MIN_BATCH_INFERENCE_SIZE
......@@ -160,7 +163,7 @@ def doc_analyze(
results = []
for sn, batch_image in enumerate(batch_images):
_, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
_, result = may_batch_image_analyze(batch_image, sn, ocr, show_log,layout_model, formula_enable, table_enable)
results.extend(result)
model_json = []
......@@ -214,7 +217,7 @@ def batch_doc_analyze(
batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
results = []
for sn, batch_image in enumerate(batch_images):
_, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
_, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
results.extend(result)
infer_results = []
......@@ -237,7 +240,6 @@ def may_batch_image_analyze(
idx: int,
ocr: bool,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None):
......@@ -248,9 +250,6 @@ def may_batch_image_analyze(
from magic_pdf.model.batch_analyze import BatchAnalyze
model_manager = ModelSingleton()
custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
)
images = [image for image, _, _ in images_with_extra_info]
batch_analyze = False
......@@ -276,64 +275,15 @@ def may_batch_image_analyze(
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_analyze = True
# batch_analyze = True
elif str(device).startswith('mps'):
batch_analyze = True
doc_analyze_start = time.time()
# batch_analyze = True
pass
if batch_analyze:
"""# batch analyze
images = []
page_wh_list = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
"""
batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
results = batch_model(images_with_extra_info)
"""
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
page_width, page_height = page_wh_list.pop(0)
else:
result = []
page_height = 0
page_width = 0
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
"""
else:
# single analyze
"""
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
doc_analyze_start = time.time()
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
"""
results = []
for img_idx, img in enumerate(images):
inference_start = time.time()
result = custom_model(img)
logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
results.append(result)
batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
results = batch_model(images_with_extra_info)
gc_start = time.time()
clean_memory(get_device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment