Unverified commit 0f91fcf6, authored by Xiaomeng Zhao and committed by GitHub
Browse files

feat(cli&analyze&pipeline): add start_page and end_page args for pagination (#507)

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

  Add start_page_id and end_page_id arguments to various components of the PDF parsing
  pipeline to support pagination functionality. This feature allows users to specify the
  range of pages to be processed, enhancing the efficiency and flexibility of the system.

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

  Add start_page_id and end_page_id arguments to various components of the PDF parsing
  pipeline to support pagination functionality. This feature allows users to specify the
  range of pages to be processed, enhancing the efficiency and flexibility of the system.

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

  Add start_page_id and end_page_id arguments to various components of the PDF parsing
  pipeline to support pagination functionality. This feature allows users to specify the
  range of pages to be processed, enhancing the efficiency and flexibility of the system.
parent 6f58eeab
...@@ -103,20 +103,31 @@ def custom_model_init(ocr: bool = False, show_log: bool = False): ...@@ -103,20 +103,31 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
return custom_model return custom_model
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False): def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
start_page_id=0, end_page_id=None):
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log) custom_model = model_manager.get_model(ocr, show_log)
images = load_images_from_pdf(pdf_bytes) images = load_images_from_pdf(pdf_bytes)
end_page_id = end_page_id if end_page_id else len(images) - 1
if end_page_id > len(images) - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = len(images) - 1
model_json = [] model_json = []
doc_analyze_start = time.time() doc_analyze_start = time.time()
for index, img_dict in enumerate(images): for index, img_dict in enumerate(images):
img = img_dict["img"] img = img_dict["img"]
page_width = img_dict["width"] page_width = img_dict["width"]
page_height = img_dict["height"] page_height = img_dict["height"]
result = custom_model(img) if start_page_id <= index <= end_page_id:
result = custom_model(img)
else:
result = []
page_info = {"page_no": index, "height": page_height, "width": page_width} page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info} page_dict = {"layout_dets": result, "page_info": page_info}
model_json.append(page_dict) model_json.append(page_dict)
......
...@@ -210,11 +210,14 @@ def pdf_parse_union(pdf_bytes, ...@@ -210,11 +210,14 @@ def pdf_parse_union(pdf_bytes,
'''根据输入的起始范围解析pdf''' '''根据输入的起始范围解析pdf'''
end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1 end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
if end_page_id > len(pdf_docs) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length")
end_page_id = len(pdf_docs) - 1
'''初始化启动时间''' '''初始化启动时间'''
start_time = time.time() start_time = time.time()
for page_id in range(start_page_id, end_page_id + 1): for page_id, page in enumerate(pdf_docs):
'''debug时输出每页解析的耗时''' '''debug时输出每页解析的耗时'''
if debug_mode: if debug_mode:
time_now = time.time() time_now = time.time()
...@@ -224,7 +227,14 @@ def pdf_parse_union(pdf_bytes, ...@@ -224,7 +227,14 @@ def pdf_parse_union(pdf_bytes,
start_time = time_now start_time = time_now
'''解析pdf中的每一页''' '''解析pdf中的每一页'''
page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode) if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
else:
page_w = page.rect.width
page_h = page.rect.height
page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
[], [], [], [],
True, "skip page")
pdf_info_dict[f"page_{page_id}"] = page_info pdf_info_dict[f"page_{page_id}"] = page_info
"""分段""" """分段"""
......
...@@ -16,12 +16,15 @@ class AbsPipe(ABC): ...@@ -16,12 +16,15 @@ class AbsPipe(ABC):
PIP_OCR = "ocr" PIP_OCR = "ocr"
PIP_TXT = "txt" PIP_TXT = "txt"
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False): def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None):
self.pdf_bytes = pdf_bytes self.pdf_bytes = pdf_bytes
self.model_list = model_list self.model_list = model_list
self.image_writer = image_writer self.image_writer = image_writer
self.pdf_mid_data = None # 未压缩 self.pdf_mid_data = None # 未压缩
self.is_debug = is_debug self.is_debug = is_debug
self.start_page_id = start_page_id
self.end_page_id = end_page_id
def get_compress_pdf_mid_data(self): def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data) return JsonCompressor.compress_json(self.pdf_mid_data)
......
...@@ -9,17 +9,20 @@ from magic_pdf.user_api import parse_ocr_pdf ...@@ -9,17 +9,20 @@ from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe): class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False): def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
super().__init__(pdf_bytes, model_list, image_writer, is_debug) start_page_id=0, end_page_id=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True) self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug) self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -10,17 +10,20 @@ from magic_pdf.user_api import parse_txt_pdf ...@@ -10,17 +10,20 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe): class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False): def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
super().__init__(pdf_bytes, model_list, image_writer, is_debug) start_page_id=0, end_page_id=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False) self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug) self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -13,9 +13,10 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf ...@@ -13,9 +13,10 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class UNIPipe(AbsPipe): class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False): def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None):
self.pdf_type = jso_useful_key["_pdf_type"] self.pdf_type = jso_useful_key["_pdf_type"]
super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug) super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
if len(self.model_list) == 0: if len(self.model_list) == 0:
self.input_model_is_empty = True self.input_model_is_empty = True
else: else:
...@@ -26,17 +27,21 @@ class UNIPipe(AbsPipe): ...@@ -26,17 +27,21 @@ class UNIPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False) self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True) self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
def pipe_parse(self): def pipe_parse(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty) is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
is_debug=self.is_debug) is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
......
...@@ -49,11 +49,26 @@ without method specified, auto will be used by default.""", ...@@ -49,11 +49,26 @@ without method specified, auto will be used by default.""",
'--debug', '--debug',
'debug_able', 'debug_able',
type=bool, type=bool,
help=('Enables detailed debugging information during' help='Enables detailed debugging information during the execution of the CLI commands.',
'the execution of the CLI commands.', ),
default=False, default=False,
) )
def cli(path, output_dir, method, debug_able): @click.option(
'-s',
'--start',
'start_page_id',
type=int,
help='The starting page for PDF parsing, beginning from 0.',
default=0,
)
@click.option(
'-e',
'--end',
'end_page_id',
type=int,
help='The ending page for PDF parsing, beginning from 0.',
default=None,
)
def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
model_config.__use_inside_model__ = True model_config.__use_inside_model__ = True
model_config.__model_mode__ = 'full' model_config.__model_mode__ = 'full'
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
...@@ -73,6 +88,8 @@ def cli(path, output_dir, method, debug_able): ...@@ -73,6 +88,8 @@ def cli(path, output_dir, method, debug_able):
[], [],
method, method,
debug_able, debug_able,
start_page_id=start_page_id,
end_page_id=end_page_id,
) )
except Exception as e: except Exception as e:
......
...@@ -42,6 +42,8 @@ def do_parse( ...@@ -42,6 +42,8 @@ def do_parse(
f_dump_content_list=False, f_dump_content_list=False,
f_make_md_mode=MakeMode.MM_MD, f_make_md_mode=MakeMode.MM_MD,
f_draw_model_bbox=False, f_draw_model_bbox=False,
start_page_id=0,
end_page_id=None,
): ):
if debug_able: if debug_able:
logger.warning("debug mode is on") logger.warning("debug mode is on")
...@@ -58,11 +60,14 @@ def do_parse( ...@@ -58,11 +60,14 @@ def do_parse(
if parse_method == 'auto': if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list} jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True) pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id)
elif parse_method == 'txt': elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True) pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id)
elif parse_method == 'ocr': elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True) pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
start_page_id=start_page_id, end_page_id=end_page_id)
else: else:
logger.error('unknown parse method') logger.error('unknown parse method')
exit(1) exit(1)
......
...@@ -25,8 +25,9 @@ PARSE_TYPE_TXT = "txt" ...@@ -25,8 +25,9 @@ PARSE_TYPE_TXT = "txt"
PARSE_TYPE_OCR = "ocr" PARSE_TYPE_OCR = "ocr"
def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
**kwargs): start_page_id=0, end_page_id=None,
*args, **kwargs):
""" """
解析文本类pdf 解析文本类pdf
""" """
...@@ -34,7 +35,8 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit ...@@ -34,7 +35,8 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
pdf_bytes, pdf_bytes,
pdf_models, pdf_models,
imageWriter, imageWriter,
start_page_id=start_page, start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug, debug_mode=is_debug,
) )
...@@ -45,8 +47,9 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit ...@@ -45,8 +47,9 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
return pdf_info_dict return pdf_info_dict
def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args, def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
**kwargs): start_page_id=0, end_page_id=None,
*args, **kwargs):
""" """
解析ocr类pdf 解析ocr类pdf
""" """
...@@ -54,7 +57,8 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit ...@@ -54,7 +57,8 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
pdf_bytes, pdf_bytes,
pdf_models, pdf_models,
imageWriter, imageWriter,
start_page_id=start_page, start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug, debug_mode=is_debug,
) )
...@@ -65,8 +69,9 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit ...@@ -65,8 +69,9 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
return pdf_info_dict return pdf_info_dict
def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
input_model_is_empty: bool = False, input_model_is_empty: bool = False,
start_page_id=0, end_page_id=None,
*args, **kwargs): *args, **kwargs):
""" """
ocr和文本混合的pdf,全部解析出来 ocr和文本混合的pdf,全部解析出来
...@@ -78,7 +83,8 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr ...@@ -78,7 +83,8 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
pdf_bytes, pdf_bytes,
pdf_models, pdf_models,
imageWriter, imageWriter,
start_page_id=start_page, start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=is_debug, debug_mode=is_debug,
) )
except Exception as e: except Exception as e:
...@@ -89,7 +95,9 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr ...@@ -89,7 +95,9 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False): if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr") logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
if input_model_is_empty: if input_model_is_empty:
pdf_models = doc_analyze(pdf_bytes, ocr=True) pdf_models = doc_analyze(pdf_bytes, ocr=True,
start_page_id=start_page_id,
end_page_id=end_page_id)
pdf_info_dict = parse_pdf(parse_pdf_by_ocr) pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None: if pdf_info_dict is None:
raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.") raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment