Commit e9c24739 authored by icecraft

style: remove unused code

parent ecdd162f
@@ -34,8 +34,6 @@ from magic_pdf.model.model_list import MODEL
 # from magic_pdf.operators.models import InferenceResult

-MIN_BATCH_INFERENCE_SIZE = 100
-

 class ModelSingleton:
     _instance = None
     _models = {}
@@ -143,17 +141,14 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )

-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for index in range(len(dataset)):
@@ -163,41 +158,16 @@ def doc_analyze(
         images.append(img_dict['img'])
         page_wh_list.append((img_dict['width'], img_dict['height']))

-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)

     model_json = []
     for index in range(len(dataset)):
@@ -224,11 +194,8 @@ def batch_doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-    one_shot: bool = True,
 ):
-    parallel_count = None
-    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
-        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
     images = []
     page_wh_list = []
     for dataset in datasets:
@@ -238,40 +205,17 @@ def batch_doc_analyze(
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))

-    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
-        if parallel_count is None:
-            parallel_count = 2 # should check the gpu memory firstly !
-        # split images into parallel_count batches
-        if parallel_count > 1:
-            batch_size = (len(images) + parallel_count - 1) // parallel_count
-            batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
-        else:
-            batch_images = [images]
-        results = []
-        parallel_count = len(batch_images) # adjust to real parallel count
-        # using concurrent.futures to analyze
-        """
-        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
-            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
-            for future in fut.as_completed(futures):
-                sn, result = future.result()
-                result_history[sn] = result
-        for key in sorted(result_history.keys()):
-            results.extend(result_history[key])
-        """
-        results = []
-        pool = mp.Pool(processes=parallel_count)
-        mapped_results = pool.starmap(may_batch_image_analyze, [(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)])
-        for sn, result in mapped_results:
-            results.extend(result)
+    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        batch_size = MIN_BATCH_INFERENCE_SIZE
+        batch_images = [images[i:i+batch_size] for i in range(0, len(images), batch_size)]
     else:
-        _, results = may_batch_image_analyze(
-            images,
-            0,
-            ocr,
-            show_log,
-            lang, layout_model, formula_enable, table_enable)
+        batch_images = [images]
+
+    results = []
+    for sn, batch_image in enumerate(batch_images):
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        results.extend(result)

     infer_results = []
     from magic_pdf.operators.models import InferenceResult
...
@@ -314,7 +314,7 @@ def batch_do_parse(
            dss.append(PymuDocDataset(v, lang=lang))
         else:
            dss.append(v)
-    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable, one_shot=True)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
         _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox)
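One detail worth noting about the new configuration line (an observation about the diff, not documented behavior): os.environ.get returns a string when MINERU_MIN_BATCH_INFERENCE_SIZE is set and the int default 100 when it is not, so the surrounding int(...) normalizes both cases, and a non-numeric value would raise ValueError. A quick check:

import os

os.environ.pop('MINERU_MIN_BATCH_INFERENCE_SIZE', None)
print(int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100)))  # 100, from the int default

os.environ['MINERU_MIN_BATCH_INFERENCE_SIZE'] = '256'
print(int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100)))  # 256, parsed from the string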