Commit b0e220c5 authored by myhloli

refactor(demo): simplify batch_demo.py and update demo.py

- Remove unnecessary imports and code in batch_demo.py
- Update demo.py to use relative paths and improve code structure
- Adjust output directory structure in both scripts
- Remove redundant code and simplify functions
parent dbdb99f9
demo/batch_demo.py

 import os
-import shutil
-import tempfile
 from pathlib import Path
-import click
-import fitz
-from loguru import logger
-import magic_pdf.model as model_config
 from magic_pdf.data.batch_build_dataset import batch_build_dataset
-from magic_pdf.data.data_reader_writer import FileBasedDataReader
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.version import __version__
-from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
-from magic_pdf.utils.office_to_pdf import convert_file_to_pdf
+from magic_pdf.tools.common import batch_do_parse


 def batch(pdf_dir, output_dir, method, lang):
-    model_config.__use_inside_model__ = True
-    model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
     doc_paths = []
     for doc_path in Path(pdf_dir).glob('*'):
         if doc_path.suffix == '.pdf':
             doc_paths.append(doc_path)

     # build dataset with 2 workers
-    datasets = batch_build_dataset(doc_paths, 2, lang)
-    os.environ["MINERU_MIN_BATCH_INFERENCE_SIZE"] = "10"  # every 10 pages will be parsed in one batch
-    batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths], datasets, method, True)
+    datasets = batch_build_dataset(doc_paths, 4, lang)
+    # os.environ["MINERU_MIN_BATCH_INFERENCE_SIZE"] = "200"  # every 200 pages will be parsed in one batch
+    batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths], datasets, method)


 if __name__ == '__main__':
-    batch("batch_data", "output", "ocr", "en")
+    batch("pdfs", "output", "auto", "")
demo/demo.py

@@ -7,18 +7,17 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.config.enums import SupportedPdfParseMethod

 # args
-pdf_file_name = "demo1.pdf"  # replace with the real pdf path
-name_without_suff = pdf_file_name.split(".")[0]
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+pdf_file_name = os.path.join(__dir__, "pdfs", "demo1.pdf")  # replace with the real pdf path
+name_without_extension = os.path.basename(pdf_file_name).split('.')[0]

 # prepare env
-local_image_dir, local_md_dir = "output/images", "output"
+local_image_dir = os.path.join(__dir__, "output", name_without_extension, "images")
+local_md_dir = os.path.join(__dir__, "output", name_without_extension)
 image_dir = str(os.path.basename(local_image_dir))

 os.makedirs(local_image_dir, exist_ok=True)

-image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
-    local_md_dir
-)
+image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)

 # read bytes
 reader1 = FileBasedDataReader("")
@@ -41,32 +40,29 @@ else:

 ## pipeline
 pipe_result = infer_result.pipe_txt_mode(image_writer)

-### draw model result on each page
-infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))
-
 ### get model inference result
 model_inference_result = infer_result.get_infer_res()

 ### draw layout result on each page
-pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))
+pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_extension}_layout.pdf"))

 ### draw spans result on each page
-pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))
+pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_extension}_spans.pdf"))

 ### get markdown content
 md_content = pipe_result.get_markdown(image_dir)

 ### dump markdown
-pipe_result.dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+pipe_result.dump_md(md_writer, f"{name_without_extension}.md", image_dir)

 ### get content list content
 content_list_content = pipe_result.get_content_list(image_dir)

 ### dump content list
-pipe_result.dump_content_list(md_writer, f"{name_without_suff}_content_list.json", image_dir)
+pipe_result.dump_content_list(md_writer, f"{name_without_extension}_content_list.json", image_dir)

 ### get middle json
 middle_json_content = pipe_result.get_middle_json()

 ### dump middle json
-pipe_result.dump_middle_json(md_writer, f'{name_without_suff}_middle.json')
+pipe_result.dump_middle_json(md_writer, f'{name_without_extension}_middle.json')
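A short sketch of what the new path handling does: each input document now gets its own folder under the demo's output/ directory. The assignments below are taken from the diff above; the output file names in the comment follow from the dump calls and are otherwise an assumption.

```python
import os

# Path construction from the updated demo.py, for the bundled pdfs/demo1.pdf.
# With the dump calls above, outputs would land roughly as:
#   <demo dir>/output/demo1/demo1.md
#   <demo dir>/output/demo1/demo1_layout.pdf
#   <demo dir>/output/demo1/images/...
__dir__ = os.path.dirname(os.path.abspath(__file__))
pdf_file_name = os.path.join(__dir__, "pdfs", "demo1.pdf")
name_without_extension = os.path.basename(pdf_file_name).split('.')[0]  # -> "demo1"
local_md_dir = os.path.join(__dir__, "output", name_without_extension)  # per-document output folder
local_image_dir = os.path.join(local_md_dir, "images")                  # images next to the markdown
image_dir = os.path.basename(local_image_dir)                           # "images": relative link used in the markdown
```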
@@ -92,7 +92,7 @@ You can find the `magic-pdf.json` file in your user directory.
 Download a sample file from the repository and test it.

 ```sh
-wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf
+wget https://github.com/opendatalab/MinerU/raw/master/demo/pdfs/small_ocr.pdf
 magic-pdf -p small_ocr.pdf -o ./output
 ```
@@ -91,7 +91,7 @@ pip install -U magic-pdf[full] -i https://mirrors.aliyun.com/pypi/simple
 Download a sample file from the repository and test it.

 ```bash
-wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/demo/small_ocr.pdf
+wget https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/demo/pdfs/small_ocr.pdf
 magic-pdf -p small_ocr.pdf -o ./output
 ```
@@ -53,7 +53,7 @@ You can find the `magic-pdf.json` file in your 【user directory】 .
 Download a sample file from the repository and test it.

 ```powershell
-wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
+wget https://github.com/opendatalab/MinerU/raw/master/demo/pdfs/small_ocr.pdf -O small_ocr.pdf
 magic-pdf -p small_ocr.pdf -o ./output
 ```
@@ -54,7 +54,7 @@ pip install -U magic-pdf[full] -i https://mirrors.aliyun.com/pypi/simple
 Download a sample file from the repository and test it.

 ```powershell
-wget https://github.com/opendatalab/MinerU/raw/master/demo/small_ocr.pdf -O small_ocr.pdf
+wget https://github.com/opendatalab/MinerU/raw/master/demo/pdfs/small_ocr.pdf -O small_ocr.pdf
 magic-pdf -p small_ocr.pdf -o ./output
 ```
@@ -3,21 +3,6 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

-import os
-import sys
-import numpy as np
-# import paddle
-import signal
-import random
-
-__dir__ = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
-
-import copy
-# from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
-# import paddle.distributed as dist
 from .imaug import transform, create_operators
@@ -16,7 +16,6 @@ class TextSystem(object):
         if self.use_angle_cls:
             self.text_classifier = predict_cls.TextClassifier(args, **kwargs)

     def get_rotate_crop_image(self, img, points):
         '''
        img_height, img_width = img.shape[0:2]
@@ -90,8 +90,6 @@ without method specified, auto will be used by default.""",
     default=None,
 )
 def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
-    model_config.__use_inside_model__ = True
-    model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
     temp_dir = tempfile.mkdtemp()

     def read_fn(path: Path):
magic_pdf/tools/common.py

@@ -73,7 +73,7 @@ def _do_parse(
     pdf_bytes_or_dataset,
     model_list,
     parse_method,
-    debug_able,
+    debug_able=False,
     f_draw_span_bbox=True,
     f_draw_layout_bbox=True,
     f_dump_md=True,
@@ -250,7 +250,7 @@ def do_parse(
     pdf_bytes_or_dataset,
     model_list,
     parse_method,
-    debug_able,
+    debug_able=False,
     f_draw_span_bbox=True,
     f_draw_layout_bbox=True,
     f_dump_md=True,
@@ -291,7 +291,7 @@ def batch_do_parse(
     pdf_file_names: list[str],
     pdf_bytes_or_datasets: list[bytes | Dataset],
     parse_method,
-    debug_able,
+    debug_able=False,
     f_draw_span_bbox=True,
     f_draw_layout_bbox=True,
     f_dump_md=True,
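With debug_able now defaulting to False across _do_parse, do_parse, and batch_do_parse, call sites such as batch_demo.py can drop the trailing True argument. A hedged sketch of the resulting call pattern (directory names are placeholders; anything in the signatures beyond what the hunks above show is assumed unchanged):

```python
from pathlib import Path

from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.tools.common import batch_do_parse

doc_paths = sorted(Path("pdfs").glob("*.pdf"))     # placeholder input directory
datasets = batch_build_dataset(doc_paths, 4, "")   # 4 dataset-build workers, as in batch_demo.py

# debug_able can now be omitted and defaults to False...
batch_do_parse("output", [p.stem for p in doc_paths], datasets, "auto")
# ...or enabled explicitly by keyword where the old behaviour is wanted.
batch_do_parse("output", [p.stem for p in doc_paths], datasets, "auto", debug_able=True)
```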