config_reader.py 2.11 KB
Newer Older
kernel.h@qq.com's avatar
kernel.h@qq.com committed
1
2
3
4
"""
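Read settings from ~/magic-pdf.json: the S3 (AK, SK, endpoint) triple for a given bucket name, plus the local models directory and device mode.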

"""
许瑞's avatar
许瑞 committed
5

6
7
8
9
10
import json
import os

from loguru import logger

11
12
from magic_pdf.libs.commons import parse_bucket_key

13
14
15
# Name of the JSON config file; read_config() looks for it in the user's home directory.
CONFIG_FILE_NAME = "magic-pdf.json"

kernel.h@qq.com's avatar
kernel.h@qq.com committed
16

许瑞's avatar
许瑞 committed
17
def read_config():
    """Load and return the magic-pdf configuration.

    The file is ``CONFIG_FILE_NAME`` inside the user's home directory
    (``~/magic-pdf.json``). It is parsed as UTF-8 JSON on every call, so
    edits to the file are picked up without restarting.

    Returns:
        The parsed JSON content (expected to be a dict for a well-formed
        config file).

    Raises:
        FileNotFoundError: if the config file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    config_file = os.path.join(os.path.expanduser("~"), CONFIG_FILE_NAME)

    # Open directly instead of calling os.path.exists() first: check-then-open
    # is racy (the file can vanish in between) and costs an extra stat call.
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"{config_file} not found") from err


def get_s3_config(bucket_name: str):
    """Return the S3 ``(access_key, secret_key, endpoint)`` triple for a bucket.

    The values come from the ``bucket_info`` object in ``~/magic-pdf.json``,
    whose entries are ``[ak, sk, endpoint]`` lists keyed by bucket name. A
    bucket without its own entry falls back to the ``"[default]"`` entry.

    Args:
        bucket_name: Name of the bucket whose credentials are wanted.

    Returns:
        A tuple ``(access_key, secret_key, storage_endpoint)``.

    Raises:
        FileNotFoundError: if the config file does not exist.
        Exception: if ``bucket_info`` is missing or not an object, if neither
            the bucket nor ``"[default]"`` has an entry, if the selected entry
            is not a three-item list, or if any of the three values is null.
    """
    config = read_config()

    bucket_info = config.get("bucket_info")
    # Guard explicitly: otherwise a missing key surfaces as an opaque
    # TypeError from the `in` test below.
    if not isinstance(bucket_info, dict):
        raise Exception(f"'bucket_info' not found in {CONFIG_FILE_NAME}")

    if bucket_name in bucket_info:
        entry = bucket_info[bucket_name]
    elif "[default]" in bucket_info:
        entry = bucket_info["[default]"]
    else:
        raise Exception(
            f"no entry for bucket '{bucket_name}' and no '[default]' entry "
            f"in 'bucket_info' of {CONFIG_FILE_NAME}"
        )

    try:
        access_key, secret_key, storage_endpoint = entry
    except (TypeError, ValueError) as err:
        raise Exception(
            f"'bucket_info' entry used for bucket '{bucket_name}' in "
            f"{CONFIG_FILE_NAME} must be an [ak, sk, endpoint] list"
        ) from err

    if access_key is None or secret_key is None or storage_endpoint is None:
        raise Exception(f"ak, sk or endpoint not found in {CONFIG_FILE_NAME}")

    # Do not log these values: the secret key would end up in plain text.
    return access_key, secret_key, storage_endpoint
48
49


50
51
52
53
54
55
56
57
58
59
def get_s3_config_dict(path: str):
    """Return the S3 credentials for the bucket that *path* lives in.

    The bucket is derived from *path* via ``get_bucket_name`` and its
    credentials resolved with ``get_s3_config``. The result is a dict with
    the keys ``"ak"``, ``"sk"`` and ``"endpoint"``.
    """
    bucket = get_bucket_name(path)
    ak, sk, endpoint = get_s3_config(bucket)
    return dict(ak=ak, sk=sk, endpoint=endpoint)


def get_bucket_name(path):
    """Return the bucket component of *path*, as split by ``parse_bucket_key``."""
    bucket_name, _key = parse_bucket_key(path)
    return bucket_name


60
61
def get_local_models_dir():
    """Return the local models directory configured under ``models-dir``.

    If the key is absent (or null) in the config file, a warning is logged
    and ``"/tmp/models"`` is returned instead.
    """
    models_dir = read_config().get("models-dir")
    if models_dir is not None:
        return models_dir
    logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
    return "/tmp/models"
68
69
70
71


def get_device():
    """Return the device mode configured under ``device-mode``.

    If the key is absent (or null) in the config file, a warning is logged
    and ``"cpu"`` is returned instead.
    """
    device_mode = read_config().get("device-mode")
    if device_mode is not None:
        return device_mode
    logger.warning(f"'device-mode' not found in {CONFIG_FILE_NAME}, use 'cpu' as default")
    return "cpu"
78
79


许瑞's avatar
许瑞 committed
80
if __name__ == "__main__":
    # Ad-hoc smoke check: resolve the credentials for the "llm-raw" bucket.
    # Fails if the config file is missing or has no usable entry for it.
    s3_triple = get_s3_config("llm-raw")
    ak, sk, endpoint = s3_triple