Unverified commit a8831ba6, authored by Xiaomeng Zhao, committed by GitHub
Browse files

Merge pull request #1419 from myhloli/dev

feat:Add NPU support
parents ad9abc32 8a0aa7a4
...@@ -14,7 +14,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType ...@@ -14,7 +14,7 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5 from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
...@@ -91,6 +91,7 @@ def chars_to_content(span): ...@@ -91,6 +91,7 @@ def chars_to_content(span):
content = '' content = ''
for char in span['chars']: for char in span['chars']:
# 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格 # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
char1 = char char1 = char
char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
...@@ -182,7 +183,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -182,7 +183,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
for block in text_blocks_raw: for block in text_blocks_raw:
for line in block['lines']: for line in block['lines']:
cosine, sine = line['dir'] cosine, sine = line['dir']
if abs (cosine) < 0.9 or abs(sine) > 0.1: if abs(cosine) < 0.9 or abs(sine) > 0.1:
continue continue
for span in line['spans']: for span in line['spans']:
all_pymu_chars.extend(span['chars']) all_pymu_chars.extend(span['chars'])
...@@ -280,13 +281,21 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -280,13 +281,21 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
def model_init(model_name: str): def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification from transformers import LayoutLMv3ForTokenClassification
device = get_device()
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device('cuda') device = torch.device('cuda')
if torch.cuda.is_bf16_supported(): if torch.cuda.is_bf16_supported():
supports_bfloat16 = True supports_bfloat16 = True
else: else:
supports_bfloat16 = False supports_bfloat16 = False
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
device = torch.device('npu')
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
else: else:
device = torch.device('cpu') device = torch.device('cpu')
supports_bfloat16 = False supports_bfloat16 = False
...@@ -860,7 +869,7 @@ def pdf_parse_union( ...@@ -860,7 +869,7 @@ def pdf_parse_union(
'pdf_info': pdf_info_list, 'pdf_info': pdf_info_list,
} }
clean_memory() clean_memory(get_device())
return new_pdf_info_dict return new_pdf_info_dict
......
...@@ -5,6 +5,7 @@ from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text ...@@ -5,6 +5,7 @@ from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from openai import OpenAI from openai import OpenAI
#@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容: formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
1. 修正渲染或编译错误: 1. 修正渲染或编译错误:
......
...@@ -9,6 +9,7 @@ from magic_pdf.config.enums import SupportedPdfParseMethod ...@@ -9,6 +9,7 @@ from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.libs.draw_bbox import draw_char_bbox
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.operators.models import InferenceResult from magic_pdf.operators.models import InferenceResult
...@@ -83,6 +84,7 @@ def do_parse( ...@@ -83,6 +84,7 @@ def do_parse(
f_make_md_mode=MakeMode.MM_MD, f_make_md_mode=MakeMode.MM_MD,
f_draw_model_bbox=False, f_draw_model_bbox=False,
f_draw_line_sort_bbox=False, f_draw_line_sort_bbox=False,
f_draw_char_bbox=False,
start_page_id=0, start_page_id=0,
end_page_id=None, end_page_id=None,
lang=None, lang=None,
...@@ -94,6 +96,7 @@ def do_parse( ...@@ -94,6 +96,7 @@ def do_parse(
logger.warning('debug mode is on') logger.warning('debug mode is on')
f_draw_model_bbox = True f_draw_model_bbox = True
f_draw_line_sort_bbox = True f_draw_line_sort_bbox = True
# f_draw_char_bbox = True
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf( pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes, start_page_id, end_page_id pdf_bytes, start_page_id, end_page_id
...@@ -205,6 +208,9 @@ def do_parse( ...@@ -205,6 +208,9 @@ def do_parse(
os.path.join(local_md_dir, f'{pdf_file_name}_line_sort.pdf') os.path.join(local_md_dir, f'{pdf_file_name}_line_sort.pdf')
) )
if f_draw_char_bbox:
draw_char_bbox(pdf_bytes, local_md_dir, f'{pdf_file_name}_char_bbox.pdf')
if f_dump_md: if f_dump_md:
pipe_result.dump_md( pipe_result.dump_md(
md_writer, md_writer,
......
...@@ -183,6 +183,7 @@ def to_pdf(file_path): ...@@ -183,6 +183,7 @@ def to_pdf(file_path):
return tmp_file_path return tmp_file_path
if __name__ == '__main__': if __name__ == '__main__':
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML(header) gr.HTML(header)
......
...@@ -4,7 +4,7 @@ click>=8.1.7 ...@@ -4,7 +4,7 @@ click>=8.1.7
fast-langdetect==0.2.0 fast-langdetect==0.2.0
loguru>=0.6.0 loguru>=0.6.0
numpy>=1.21.6,<2.0.0 numpy>=1.21.6,<2.0.0
pydantic>=2.7.2,<2.8.0 pydantic>=2.7.2
PyMuPDF>=1.24.9 PyMuPDF>=1.24.9
scikit-learn>=1.0.2 scikit-learn>=1.0.2
torch>=2.2.2 torch>=2.2.2
......
...@@ -50,6 +50,7 @@ if __name__ == '__main__': ...@@ -50,6 +50,7 @@ if __name__ == '__main__':
"accelerate", # struct-eqtable依赖 "accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo "doclayout_yolo==0.0.2", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle "rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime",
"rapid_table", # rapid_table "rapid_table", # rapid_table
"PyYAML", # yaml "PyYAML", # yaml
"openai", # openai SDK "openai", # openai SDK
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment