Commit 98b8c4a9 authored by myhloli
Browse files

refactor: streamline formula and table enable configurations in the pipeline

parent cf5c8f47
...@@ -5,7 +5,7 @@ import PIL.Image ...@@ -5,7 +5,7 @@ import PIL.Image
import torch import torch
from .model_init import MineruPipelineModel from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device, get_formula_config, get_table_recog_config from mineru.utils.config_reader import get_device
from ...utils.pdf_classify import classify from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf from ...utils.pdf_image_tools import load_images_from_pdf
...@@ -44,20 +44,15 @@ class ModelSingleton: ...@@ -44,20 +44,15 @@ class ModelSingleton:
def custom_model_init( def custom_model_init(
lang=None, lang=None,
formula_enable=None, formula_enable=True,
table_enable=None, table_enable=True,
): ):
model_init_start = time.time() model_init_start = time.time()
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
device = get_device() device = get_device()
formula_config = get_formula_config() formula_config = {"enable": formula_enable}
if formula_enable is not None: table_config = {"enable": table_enable}
formula_config['enable'] = formula_enable
table_config = get_table_recog_config()
if table_enable is not None:
table_config['enable'] = table_enable
model_input = { model_input = {
'device': device, 'device': device,
...@@ -78,8 +73,8 @@ def doc_analyze( ...@@ -78,8 +73,8 @@ def doc_analyze(
pdf_bytes_list, pdf_bytes_list,
lang_list, lang_list,
parse_method: str = 'auto', parse_method: str = 'auto',
formula_enable=None, formula_enable=True,
table_enable=None, table_enable=True,
): ):
MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100)) MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
...@@ -152,8 +147,8 @@ def doc_analyze( ...@@ -152,8 +147,8 @@ def doc_analyze(
def batch_image_analyze( def batch_image_analyze(
images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]], images_with_extra_info: List[Tuple[PIL.Image.Image, bool, str]],
formula_enable=None, formula_enable=True,
table_enable=None): table_enable=True):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx) # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
from .batch_analyze import BatchAnalyze from .batch_analyze import BatchAnalyze
...@@ -194,6 +189,10 @@ def batch_image_analyze( ...@@ -194,6 +189,10 @@ def batch_image_analyze(
batch_ratio = 1 batch_ratio = 1
logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}') logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
if os.getenv('MINERU_FORMULA_ENABLE', None) is not None:
formula_enable = os.getenv('MINERU_FORMULA_ENABLE').lower() == 'true'
if os.getenv('MINERU_TABLE_ENABLE', None) is not None:
table_enable = os.getenv('MINERU_TABLE_ENABLE').lower() == 'true'
batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable) batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
results = batch_model(images_with_extra_info) results = batch_model(images_with_extra_info)
......
...@@ -140,10 +140,6 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes ...@@ -140,10 +140,6 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source): def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
if os.getenv('MINERU_FORMULA_ENABLE', None) is None:
os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable).lower()
if os.getenv('MINERU_TABLE_ENABLE', None) is None:
os.environ['MINERU_TABLE_ENABLE'] = str(table_enable).lower()
def get_device_mode() -> str: def get_device_mode() -> str:
if device_mode is not None: if device_mode is not None:
return device_mode return device_mode
...@@ -184,6 +180,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i ...@@ -184,6 +180,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
p_lang_list=lang_list, p_lang_list=lang_list,
backend=backend, backend=backend,
parse_method=method, parse_method=method,
p_formula_enable=formula_enable,
p_table_enable=table_enable,
server_url=server_url, server_url=server_url,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id end_page_id=end_page_id
......
...@@ -115,6 +115,9 @@ def do_parse( ...@@ -115,6 +115,9 @@ def do_parse(
pdf_doc = all_pdf_docs[idx] pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx] _lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx] _ocr_enable = ocr_enabled_list[idx]
if os.getenv('MINERU_FORMULA_ENABLE', None) is not None:
p_formula_enable = os.getenv('MINERU_FORMULA_ENABLE').lower() == 'true'
middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable) middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable)
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
......
...@@ -86,24 +86,6 @@ def get_device(): ...@@ -86,24 +86,6 @@ def get_device():
return "cpu" return "cpu"
def get_table_recog_config():
    """Return the table-recognition configuration as ``{"enable": <bool>}``.

    The value is taken from the ``MINERU_TABLE_ENABLE`` environment
    variable. When the variable is not set, table recognition defaults to
    enabled.

    Returns:
        dict: A single-key mapping ``{"enable": True}`` or
        ``{"enable": False}``.
    """
    raw_value = os.getenv('MINERU_TABLE_ENABLE', None)
    if raw_value is None:
        # Variable not set: table recognition is on by default.
        return {"enable": True}
    # Parse the flag case-insensitively instead of splicing the raw text into
    # a JSON document: values such as "True" or "FALSE" are not valid JSON
    # literals and used to raise json.JSONDecodeError. "1" is still accepted
    # as an enabled flag, as it was truthy before.
    return {"enable": raw_value.strip().lower() in ('true', '1')}
def get_formula_config():
    """Return the formula-recognition configuration as ``{"enable": <bool>}``.

    The value is taken from the ``MINERU_FORMULA_ENABLE`` environment
    variable. When the variable is not set, formula recognition defaults to
    enabled.

    Returns:
        dict: A single-key mapping ``{"enable": True}`` or
        ``{"enable": False}``.
    """
    raw_value = os.getenv('MINERU_FORMULA_ENABLE', None)
    if raw_value is None:
        # Variable not set: formula recognition is on by default.
        return {"enable": True}
    # Parse the flag case-insensitively instead of splicing the raw text into
    # a JSON document: values such as "True" or "FALSE" are not valid JSON
    # literals and used to raise json.JSONDecodeError. "1" is still accepted
    # as an enabled flag, as it was truthy before.
    return {"enable": raw_value.strip().lower() in ('true', '1')}
def get_latex_delimiter_config(): def get_latex_delimiter_config():
config = read_config() config = read_config()
if config is None: if config is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment