config_reader.py 4.27 KB
Newer Older
"""Return the S3 (access key, secret key, endpoint) triple for a given bucket name."""
许瑞's avatar
许瑞 committed
2

3
4
5
6
7
import json
import os

from loguru import logger

8
from magic_pdf.config.constants import MODEL_NAME
9
10
from magic_pdf.libs.commons import parse_bucket_key

11
# Name of the config file; can be overridden with the MINERU_TOOLS_CONFIG_JSON environment variable.
12
CONFIG_FILE_NAME = os.getenv('MINERU_TOOLS_CONFIG_JSON', 'magic-pdf.json')
13

kernel.h@qq.com's avatar
kernel.h@qq.com committed
14

许瑞's avatar
许瑞 committed
15
def read_config():
    """Load the JSON configuration file and return its parsed content.

    The location comes from the module-level ``CONFIG_FILE_NAME``: an
    absolute path is used as-is, anything else is resolved relative to the
    current user's home directory.

    Returns:
        The object produced by ``json.load`` for the config file.

    Raises:
        FileNotFoundError: If the resolved config file does not exist.
    """
    if os.path.isabs(CONFIG_FILE_NAME):
        config_file = CONFIG_FILE_NAME
    else:
        config_file = os.path.join(os.path.expanduser('~'), CONFIG_FILE_NAME)

    if not os.path.exists(config_file):
        raise FileNotFoundError(f'{config_file} not found')

    with open(config_file, 'r', encoding='utf-8') as f:
        return json.load(f)


def get_s3_config(bucket_name: str):
    """Return the S3 credentials configured for ``bucket_name``.

    Credentials are read from the ``bucket_info`` mapping of the config
    file. A bucket without its own entry falls back to the ``[default]``
    entry.

    Args:
        bucket_name: Name of the bucket to look up.

    Returns:
        An ``(access_key, secret_key, storage_endpoint)`` tuple.

    Raises:
        KeyError: If neither ``bucket_name`` nor ``[default]`` has an entry
            in ``bucket_info``.
        Exception: If ``bucket_info`` is missing from the config, or if any
            of the three values is ``None``.
    """
    config = read_config()

    bucket_info = config.get('bucket_info')
    if bucket_info is None:
        # Fail with a readable message instead of a TypeError on `in None`.
        raise Exception(f"'bucket_info' not found in {CONFIG_FILE_NAME}")

    entry_name = bucket_name if bucket_name in bucket_info else '[default]'
    if entry_name not in bucket_info:
        raise KeyError(
            f"neither '{bucket_name}' nor '[default]' found in 'bucket_info' of {CONFIG_FILE_NAME}"
        )
    access_key, secret_key, storage_endpoint = bucket_info[entry_name]

    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception(f'ak, sk or endpoint not found in {CONFIG_FILE_NAME}')

    # Credentials are deliberately never logged here.
    return access_key, secret_key, storage_endpoint
46
47


48
49
def get_s3_config_dict(path: str):
    """Return the S3 credentials for the bucket that ``path`` refers to, as a dict.

    Args:
        path: A path from which ``get_bucket_name`` can extract a bucket name.

    Returns:
        A dict with the keys ``ak``, ``sk`` and ``endpoint``.
    """
    access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
    return {'ak': access_key, 'sk': secret_key, 'endpoint': storage_endpoint}
51
52
53
54
55
56
57


def get_bucket_name(path):
    """Return the bucket component of ``path`` as split by ``parse_bucket_key``."""
    # Only the bucket is needed; the key half of the pair is discarded.
    bucket_part, _ = parse_bucket_key(path)
    return bucket_part


58
59
def get_local_models_dir():
    """Return the local models directory from the ``models-dir`` config key.

    Returns:
        The configured value, or ``'/tmp/models'`` (after logging a warning)
        when the key is absent or null.
    """
    models_dir = read_config().get('models-dir')
    if models_dir is None:
        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
        return '/tmp/models'
    return models_dir
66
67


68
69
def get_local_layoutreader_model_dir():
    """Return the directory that holds the layoutreader model.

    Returns:
        The ``layoutreader-model-dir`` config value when it is set and the
        path exists; otherwise a path under ``~/.cache/modelscope/hub``
        (a warning is logged in that case).
    """
    layoutreader_model_dir = read_config().get('layoutreader-model-dir')
    if layoutreader_model_dir is not None and os.path.exists(layoutreader_model_dir):
        return layoutreader_model_dir

    # Fallback location; note that its existence is not checked here.
    home_dir = os.path.expanduser('~')
    layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
    logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
    return layoutreader_at_modelscope_dir_path


80
81
def get_device():
    """Return the device mode from the ``device-mode`` config key.

    Returns:
        The configured value, or ``'cpu'`` (after logging a warning) when the
        key is absent or null.
    """
    device = read_config().get('device-mode')
    if device is None:
        logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
        return 'cpu'
    return device
88

89

90
91
def get_table_recog_config():
    """Return the table-recognition settings from the ``table-config`` key.

    Returns:
        The configured value, or a default dict with the keys ``model``
        (``MODEL_NAME.RAPID_TABLE``), ``enable`` (``False``) and ``max_time``
        (``400``) when the key is absent or null. A warning is logged when
        the default is used.
    """
    table_config = read_config().get('table-config')
    if table_config is None:
        logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
        # Build the default directly instead of parsing a hand-formatted JSON
        # string, which would fail if the model name held a quote or backslash.
        # The f-string keeps the value a plain formatted string.
        return {'model': f'{MODEL_NAME.RAPID_TABLE}', 'enable': False, 'max_time': 400}
    return table_config
98

99

100
101
def get_layout_config():
    """Return the layout-model settings from the ``layout-config`` key.

    Returns:
        The configured value, or ``{'model': MODEL_NAME.LAYOUTLMv3}`` (after
        logging a warning) when the key is absent or null.
    """
    layout_config = read_config().get('layout-config')
    if layout_config is None:
        logger.warning(f"'layout-config' not found in {CONFIG_FILE_NAME}, use '{MODEL_NAME.LAYOUTLMv3}' as default")
        # Build the default directly instead of parsing a hand-formatted JSON
        # string, which would fail if the model name held a quote or backslash.
        return {'model': f'{MODEL_NAME.LAYOUTLMv3}'}
    return layout_config


def get_formula_config():
    """Return the formula-recognition settings from the ``formula-config`` key.

    Returns:
        The configured value, or a default dict with the keys ``mfd_model``
        (``MODEL_NAME.YOLO_V8_MFD``), ``mfr_model``
        (``MODEL_NAME.UniMerNet_v2_Small``) and ``enable`` (``True``) when the
        key is absent or null. A warning is logged when the default is used.
    """
    formula_config = read_config().get('formula-config')
    if formula_config is None:
        logger.warning(f"'formula-config' not found in {CONFIG_FILE_NAME}, use 'True' as default")
        # Build the default directly instead of parsing a hand-formatted JSON
        # string, which would fail if a model name held a quote or backslash.
        return {
            'mfd_model': f'{MODEL_NAME.YOLO_V8_MFD}',
            'mfr_model': f'{MODEL_NAME.UniMerNet_v2_Small}',
            'enable': True,
        }
    return formula_config

119
120
121
122
123
124
125
126
127
def get_llm_aided_config():
    """Return the value of the ``llm-aided-config`` key, or ``None`` if unset.

    A warning is logged when the key is absent or null.
    """
    llm_settings = read_config().get('llm-aided-config')
    if llm_settings is not None:
        return llm_settings

    logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
    return None

128

129
130
if __name__ == '__main__':
    # Manual smoke check: resolve the credentials of one sample bucket.
    ak, sk, endpoint = get_s3_config(bucket_name='llm-raw')