"LICENSE_code" did not exist on "cfce9fbfe7b0c4e70213a9cebbd01056e19375e4"
Unverified commit ec566d22, authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2004 from icecraft/feat/remove_old_inference_code

feat: remove old inference code
parents f6bc4f70 4fbc3689
@@ -150,7 +150,10 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+    if lang is None or lang == 'auto':
+        images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+    else:
+        images_with_extra_info = [(images[index], ocr, lang) for index in range(len(dataset))]

     if len(images) >= MIN_BATCH_INFERENCE_SIZE:
         batch_size = MIN_BATCH_INFERENCE_SIZE
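The branch added above decides which language tag travels with each page tuple. Below is a minimal, self-contained sketch of that selection logic; `SimpleDataset` and `build_images_with_extra_info` are hypothetical illustration names, not magic_pdf APIs, and only the `_lang` attribute and the `(image, ocr, lang)` tuple shape are taken from the diff.

# Illustrative sketch of the language-selection logic added in the hunk above.
# SimpleDataset is a hypothetical stand-in for the real magic_pdf dataset class.
class SimpleDataset:
    def __init__(self, images, lang='ch'):
        self._images = images
        self._lang = lang  # language associated with the whole dataset

    def __len__(self):
        return len(self._images)


def build_images_with_extra_info(images, ocr, lang, dataset):
    # lang=None or 'auto' -> fall back to the dataset's own language;
    # any explicit lang overrides it for every page.
    if lang is None or lang == 'auto':
        return [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
    return [(images[index], ocr, lang) for index in range(len(dataset))]


if __name__ == '__main__':
    imgs = ['page0.png', 'page1.png']
    ds = SimpleDataset(imgs, lang='en')
    print(build_images_with_extra_info(imgs, ocr=True, lang=None, dataset=ds))  # tuples carry 'en'
    print(build_images_with_extra_info(imgs, ocr=True, lang='ch', dataset=ds))  # tuples carry 'ch'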
@@ -160,7 +163,7 @@ def doc_analyze(
     results = []
     for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log,layout_model, formula_enable, table_enable)
         results.extend(result)

     model_json = []
@@ -214,7 +217,7 @@ def batch_doc_analyze(
     batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
     results = []
     for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
+        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)

     infer_results = []
@@ -237,7 +240,6 @@ def may_batch_image_analyze(
         idx: int,
         ocr: bool,
         show_log: bool = False,
-        lang=None,
         layout_model=None,
         formula_enable=None,
         table_enable=None):
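The three hunks above drop `lang` from `may_batch_image_analyze` together: both call sites stop passing it, and the parameter disappears from the signature, since each `(image, ocr, lang)` tuple already carries the language. A hedged sketch of the new call shape follows; the function body is a stub for illustration only, not the real implementation.

# Stub mirroring the new signature from the hunk above; the body is illustrative only.
def may_batch_image_analyze(images_with_extra_info, idx, ocr, show_log=False,
                            layout_model=None, formula_enable=None, table_enable=None):
    # images_with_extra_info: list of (image, ocr_flag, lang) tuples built in doc_analyze()
    langs = sorted({lang for _, _, lang in images_with_extra_info})
    return idx, [{'langs': langs} for _ in images_with_extra_info]


# Before this commit, callers passed lang positionally between show_log and layout_model;
# keeping that extra argument now would shift layout_model and the enable flags by one slot.
_, result = may_batch_image_analyze([('page0.png', True, 'en')], 0, True, False, None, None, None)
print(result)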
@@ -248,9 +250,6 @@ def may_batch_image_analyze(
     from magic_pdf.model.batch_analyze import BatchAnalyze

     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )

     images = [image for image, _, _ in images_with_extra_info]
     batch_analyze = False
@@ -276,64 +275,15 @@ def may_batch_image_analyze(
         else:
             batch_ratio = 1
         logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-        batch_analyze = True
+        # batch_analyze = True
     elif str(device).startswith('mps'):
-        batch_analyze = True
+        # batch_analyze = True
+        pass

     doc_analyze_start = time.time()

-    if batch_analyze:
-        """# batch analyze
-        images = []
-        page_wh_list = []
-        for index in range(len(dataset)):
-            if start_page_id <= index <= end_page_id:
-                page_data = dataset.get_page(index)
-                img_dict = page_data.get_image()
-                images.append(img_dict['img'])
-                page_wh_list.append((img_dict['width'], img_dict['height']))
-        """
-        batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
-        results = batch_model(images_with_extra_info)
-        """
-        for index in range(len(dataset)):
-            if start_page_id <= index <= end_page_id:
-                result = analyze_result.pop(0)
-                page_width, page_height = page_wh_list.pop(0)
-            else:
-                result = []
-                page_height = 0
-                page_width = 0
-            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        """
-    else:
-        # single analyze
-        """
-        for index in range(len(dataset)):
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            img = img_dict['img']
-            page_width = img_dict['width']
-            page_height = img_dict['height']
-            if start_page_id <= index <= end_page_id:
-                page_start = time.time()
-                result = custom_model(img)
-                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-            else:
-                result = []
-            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        """
-        results = []
-        for img_idx, img in enumerate(images):
-            inference_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
-            results.append(result)
+    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
+    results = batch_model(images_with_extra_info)

     gc_start = time.time()
     clean_memory(get_device())
...
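With the old single-image loop and the `batch_analyze` flag removed, every page goes through `BatchAnalyze` unconditionally. The sketch below illustrates that simplified flow under that assumption; `ModelSingleton` and `BatchAnalyze` here are stand-ins that only mimic the call shape visible in the diff, not the real magic_pdf classes.

# Sketch of the flow after this commit: no per-image custom_model(img) fallback,
# everything funnels through a single BatchAnalyze call. Stand-in classes, not magic_pdf's.
class ModelSingleton:
    """Placeholder for the real model-manager singleton."""


class BatchAnalyze:
    def __init__(self, model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable):
        self.model_manager = model_manager
        self.batch_ratio = batch_ratio

    def __call__(self, images_with_extra_info):
        # One result entry per (image, ocr, lang) tuple.
        return [{'layout_dets': [], 'lang': lang} for _, _, lang in images_with_extra_info]


def analyze(images_with_extra_info, batch_ratio=1, show_log=False):
    model_manager = ModelSingleton()
    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, None, None, None)
    return batch_model(images_with_extra_info)


print(analyze([('page0.png', True, 'en'), ('page1.png', True, 'en')], batch_ratio=2))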