common.py 5.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional

from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from yaml import safe_dump, safe_load

from ..extras.constants import (
    DATA_CONFIG,
    DEFAULT_MODULE,
    DEFAULT_TEMPLATE,
    PEFT_METHODS,
    STAGES_USE_PAIR_DATA,
    SUPPORTED_MODELS,
    TRAINING_STAGES,
    VISION_MODELS,
    DownloadSource,
)
from ..extras.logging import get_logger
from ..extras.misc import use_modelscope
from ..extras.packages import is_gradio_available


# gradio is optional: it is imported only when the availability check passes.
# Functions in this module that build `gr.Dropdown` / `gr.Button` need it at call time.
if is_gradio_available():
    import gradio as gr


logger = get_logger(__name__)


# Weight file names whose presence marks a directory as an adapter (see `list_adapters`).
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}
# Directory that holds the user config file (see `get_config_path`).
DEFAULT_CACHE_DIR = "cache"
# Directory where argument files are saved and loaded (see `get_save_path`).
DEFAULT_CONFIG_DIR = "config"
# Dataset directory used by `list_dataset` when none is given.
DEFAULT_DATA_DIR = "data"
# Root directory of the paths built by `get_save_dir`.
DEFAULT_SAVE_DIR = "saves"
# File name of the user config inside `DEFAULT_CACHE_DIR`.
USER_CONFIG = "user_config.yaml"


def get_save_dir(*paths: str) -> os.PathLike:
    """
    Builds a path under the default save directory from the given components.

    Path separators and spaces are removed from every component, and the
    remaining surrounding whitespace is stripped, before the parts are joined.
    """
    cleaned = []
    for component in paths:
        component = component.replace(os.path.sep, "")
        component = component.replace(" ", "")
        cleaned.append(component.strip())

    return os.path.join(DEFAULT_SAVE_DIR, *cleaned)


def get_config_path() -> os.PathLike:
    """Returns the location of the user config file inside the cache directory."""
    folder, file_name = DEFAULT_CACHE_DIR, USER_CONFIG
    return os.path.join(folder, file_name)


def get_save_path(config_path: str) -> os.PathLike:
    """Returns `config_path` joined onto the default config directory."""
    parts = (DEFAULT_CONFIG_DIR, config_path)
    return os.path.join(*parts)


def load_config() -> Dict[str, Any]:
    """
    Loads the user config from the cache directory.

    Returns:
        The parsed config mapping. A fresh default config is returned when the
        file cannot be opened or parsed, or when its content is not a mapping
        (for instance an empty file, for which `safe_load` yields `None`).
        Keys missing from a loaded config are filled with their defaults, and
        `path_dict` is always a dict.
    """
    default_config = {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
    try:
        with open(get_config_path(), "r", encoding="utf-8") as f:
            user_config = safe_load(f)
    except Exception:
        # Best effort: a missing or unreadable config file is not an error.
        return default_config

    if not isinstance(user_config, dict):
        return default_config

    # Callers index these keys directly, so make sure each one exists.
    for key, value in default_config.items():
        user_config.setdefault(key, value)

    if not isinstance(user_config["path_dict"], dict):
        user_config["path_dict"] = {}

    return user_config


def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
    """
    Updates the user config and writes it back to the cache directory.

    The stored language is replaced only when `lang` is truthy. When
    `model_name` is truthy it becomes the last used model and `model_path`
    is recorded for it in the path dictionary.
    """
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    config = load_config()
    config["lang"] = lang if lang else config["lang"]

    if model_name:
        config["last_model"] = model_name
        config["path_dict"][model_name] = model_path

    with open(get_config_path(), "w", encoding="utf-8") as writer:
        safe_dump(config, writer)


def load_args(config_path: str) -> Optional[Dict[str, Any]]:
    """
    Loads saved arguments from a file in the config directory.

    Returns the parsed YAML content, or None when the file cannot be read or parsed.
    """
    try:
        with open(get_save_path(config_path), "r", encoding="utf-8") as reader:
            loaded = safe_load(reader)
    except Exception:
        return None

    return loaded


def save_args(config_path: str, config_dict: Dict[str, Any]) -> str:
    """
    Dumps `config_dict` as YAML into a file in the config directory.

    The config directory is created if needed. Returns the written path as a string.
    """
    os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
    target = get_save_path(config_path)
    with open(target, "w", encoding="utf-8") as writer:
        safe_dump(config_dict, writer)

    return str(target)


def get_model_path(model_name: str) -> str:
    """
    Resolves the path of a model.

    A path stored in the user config takes precedence over the default
    download source listed for the model. When ModelScope is in use and the
    resolved path equals the default source, the ModelScope path (if any)
    is returned instead.
    """
    saved_paths = load_config()["path_dict"]
    sources: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
    default_path = sources.get(DownloadSource.DEFAULT, None)
    resolved = saved_paths.get(model_name, None) or default_path

    if not use_modelscope():
        return resolved

    modelscope_path = sources.get(DownloadSource.MODELSCOPE)
    if modelscope_path and resolved == default_path:  # replace path
        return modelscope_path

    return resolved


def get_prefix(model_name: str) -> str:
    """Returns the part of `model_name` that precedes the first hyphen."""
    head, _, _ = model_name.partition("-")
    return head


def get_module(model_name: str) -> str:
    """Looks up the default module string for the model's prefix, falling back to "q_proj,v_proj"."""
    prefix = get_prefix(model_name)
    return DEFAULT_MODULE.get(prefix, "q_proj,v_proj")


def get_template(model_name: str) -> str:
    """
    Picks the template for a model.

    Only names ending in "Chat" whose prefix has an entry in the default
    template table get that entry; everything else gets "default".
    """
    if not model_name or not model_name.endswith("Chat"):
        return "default"

    prefix = get_prefix(model_name)
    if prefix in DEFAULT_TEMPLATE:
        return DEFAULT_TEMPLATE[prefix]

    return "default"


def get_visual(model_name: str) -> bool:
    """Tells whether the model's prefix is listed among the vision models."""
    prefix = get_prefix(model_name)
    return prefix in VISION_MODELS


def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
    """
    Builds the adapter dropdown for a model and finetuning type.

    Finetuning types outside the PEFT methods yield an empty, non-interactive
    dropdown. Otherwise the dropdown is interactive; for "lora" with a model
    name it lists the sub-directories of the save directory that contain an
    adapter weight file.
    """
    if finetuning_type not in PEFT_METHODS:
        return gr.Dropdown(value=[], choices=[], interactive=False)

    def _has_adapter_weights(folder: str) -> bool:
        # a directory counts as an adapter when it holds one of the known weight files
        return os.path.isdir(folder) and any(
            os.path.isfile(os.path.join(folder, weight_name)) for weight_name in ADAPTER_NAMES
        )

    found = []
    if model_name and finetuning_type == "lora":
        root = get_save_dir(model_name, finetuning_type)
        if root and os.path.isdir(root):
            found = [entry for entry in os.listdir(root) if _has_adapter_weights(os.path.join(root, entry))]

    return gr.Dropdown(value=[], choices=found, interactive=True)


def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
    """
    Reads the dataset config file found in `dataset_dir`.

    Returns an empty dict when `dataset_dir` is "ONLINE" (logged at info level)
    or when the file cannot be opened or parsed (logged as a warning).
    """
    if dataset_dir != "ONLINE":
        try:
            with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as reader:
                return json.load(reader)
        except Exception as exc:
            logger.warning("Cannot open {} due to {}.".format(os.path.join(dataset_dir, DATA_CONFIG), str(exc)))
            return {}

    logger.info("dataset_dir is ONLINE, using online dataset.")
    return {}


def list_dataset(dataset_dir: str = None, training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Dropdown":
    """
    Builds the dataset dropdown for a training stage.

    Datasets are read from `dataset_dir` (the default data directory when None)
    and kept when their "ranking" flag matches whether the stage uses pair data.
    """
    source_dir = DEFAULT_DATA_DIR if dataset_dir is None else dataset_dir
    dataset_info = load_dataset_info(source_dir)
    wants_pairs = TRAINING_STAGES[training_stage] in STAGES_USE_PAIR_DATA

    choices = []
    for name, attrs in dataset_info.items():
        if attrs.get("ranking", False) == wants_pairs:
            choices.append(name)

    return gr.Dropdown(value=[], choices=choices)


def autoset_packing(training_stage: str = list(TRAINING_STAGES.keys())[0]) -> "gr.Button":
    """Returns a button whose value is True only when the stage maps to "pt"."""
    is_pretraining = TRAINING_STAGES[training_stage] == "pt"
    return gr.Button(value=is_pretraining)