Unverified Commit cdf6e0cf authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2703 from myhloli/dev

Dev
parents c8904da6 ec3adde8
...@@ -2,15 +2,12 @@ import os ...@@ -2,15 +2,12 @@ import os
import time import time
from typing import List, Tuple from typing import List, Tuple
import PIL.Image import PIL.Image
import torch from loguru import logger
from .model_init import MineruPipelineModel from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device from mineru.utils.config_reader import get_device
from ...utils.pdf_classify import classify from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf from ...utils.pdf_image_tools import load_images_from_pdf
from loguru import logger
from ...utils.model_utils import get_vram, clean_memory from ...utils.model_utils import get_vram, clean_memory
...@@ -166,7 +163,7 @@ def batch_image_analyze( ...@@ -166,7 +163,7 @@ def batch_image_analyze(
try: try:
import torch_npu import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
torch.npu.set_compile_mode(jit_compile=False) torch_npu.npu.set_compile_mode(jit_compile=False)
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(
"NPU is selected as device, but torch_npu is not available. " "NPU is selected as device, but torch_npu is not available. "
......
...@@ -8,7 +8,6 @@ from mineru.utils.pdf_image_tools import load_images_from_pdf ...@@ -8,7 +8,6 @@ from mineru.utils.pdf_image_tools import load_images_from_pdf
from .base_predictor import BasePredictor from .base_predictor import BasePredictor
from .predictor import get_predictor from .predictor import get_predictor
from .token_to_middle_json import result_to_middle_json from .token_to_middle_json import result_to_middle_json
from ...utils.enum_class import ModelPath
from ...utils.models_download_utils import auto_download_and_get_model_root_path from ...utils.models_download_utils import auto_download_and_get_model_root_path
......
...@@ -7,7 +7,7 @@ from loguru import logger ...@@ -7,7 +7,7 @@ from loguru import logger
from mineru.utils.config_reader import get_device from mineru.utils.config_reader import get_device
from mineru.utils.model_utils import get_vram from mineru.utils.model_utils import get_vram
from ..version import __version__ from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command() @click.command()
@click.version_option(__version__, @click.version_option(__version__,
...@@ -138,8 +138,7 @@ from ..version import __version__ ...@@ -138,8 +138,7 @@ from ..version import __version__
def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source): def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes if not backend.endswith('-client'):
def get_device_mode() -> str: def get_device_mode() -> str:
if device_mode is not None: if device_mode is not None:
return device_mode return device_mode
......
...@@ -8,15 +8,12 @@ from pathlib import Path ...@@ -8,15 +8,12 @@ from pathlib import Path
import pypdfium2 as pdfium import pypdfium2 as pdfium
from loguru import logger from loguru import logger
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
from mineru.data.data_reader_writer import FileBasedDataWriter from mineru.data.data_reader_writer import FileBasedDataWriter
from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
from mineru.utils.enum_class import MakeMode from mineru.utils.enum_class import MakeMode
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
pdf_suffixes = [".pdf"] pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"] image_suffixes = [".png", ".jpeg", ".jpg"]
...@@ -99,6 +96,11 @@ def do_parse( ...@@ -99,6 +96,11 @@ def do_parse(
): ):
if backend == "pipeline": if backend == "pipeline":
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
for idx, pdf_bytes in enumerate(pdf_bytes_list): for idx, pdf_bytes in enumerate(pdf_bytes_list):
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id) new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
pdf_bytes_list[idx] = new_pdf_bytes pdf_bytes_list[idx] = new_pdf_bytes
...@@ -163,6 +165,7 @@ def do_parse( ...@@ -163,6 +165,7 @@ def do_parse(
logger.info(f"local output dir is {local_md_dir}") logger.info(f"local output dir is {local_md_dir}")
else: else:
if backend.startswith("vlm-"): if backend.startswith("vlm-"):
backend = backend[4:] backend = backend[4:]
......
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import json import json
import os import os
from loguru import logger from loguru import logger
try:
import torch
import torch_npu
except ImportError:
pass
# 定义配置文件名常量 # 定义配置文件名常量
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json') CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'mineru.json')
...@@ -71,15 +77,12 @@ def get_device(): ...@@ -71,15 +77,12 @@ def get_device():
if device_mode is not None: if device_mode is not None:
return device_mode return device_mode
else: else:
import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
return "cuda" return "cuda"
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
return "mps" return "mps"
else: else:
try: try:
import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
return "npu" return "npu"
except Exception as e: except Exception as e:
......
...@@ -6,6 +6,12 @@ import numpy as np ...@@ -6,6 +6,12 @@ import numpy as np
from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
try:
import torch
import torch_npu
except ImportError:
pass
def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0): def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0):
...@@ -297,14 +303,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol ...@@ -297,14 +303,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
def clean_memory(device='cuda'): def clean_memory(device='cuda'):
import torch
if device == 'cuda': if device == 'cuda':
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
elif str(device).startswith("npu"): elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
torch_npu.npu.empty_cache() torch_npu.npu.empty_cache()
elif str(device).startswith("mps"): elif str(device).startswith("mps"):
...@@ -322,13 +325,10 @@ def clean_vram(device, vram_threshold=8): ...@@ -322,13 +325,10 @@ def clean_vram(device, vram_threshold=8):
def get_vram(device): def get_vram(device):
import torch
if torch.cuda.is_available() and str(device).startswith("cuda"): if torch.cuda.is_available() and str(device).startswith("cuda"):
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory return total_memory
elif str(device).startswith("npu"): elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory return total_memory
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment