import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from datetime import datetime

from llmtuner.extras.ploting import smooth
from llmtuner.tuner import export_model
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from llmtuner.webui.locales import ALERTS

if TYPE_CHECKING:
    from llmtuner.extras.callbacks import LogCallback


def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
    """Build a gr.update for the training progress bar from a LogCallback.

    Hides the bar when no run is active (max_steps is 0 or unset); otherwise
    shows the completion percentage and a "current/total: elapsed < remaining"
    label read from the callback's counters.
    """
    if not callback.max_steps:
        return gr.update(visible=False)

    # The early return above guarantees max_steps is non-zero, so the
    # division is safe; the original `!= 0` ternary fallback was dead code.
    percentage = round(100 * callback.cur_steps / callback.max_steps, 0)
    label = "Running {:d}/{:d}: {} < {}".format(
        callback.cur_steps,
        callback.max_steps,
        callback.elapsed_time,
        callback.remaining_time
    )
    return gr.update(label=label, value=percentage, visible=True)


def get_time() -> str:
    """Return the current local time as a 'YYYY-MM-DD-HH-MM-SS' string."""
    return "{:%Y-%m-%d-%H-%M-%S}".format(datetime.now())


def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
    """Enable the preview button only when the first selected dataset
    maps to a local file that actually exists under dataset_dir.
    """
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    # Guard clauses: nothing selected, or the entry has no local file.
    if not dataset:
        return gr.update(interactive=False)

    entry = dataset_info[dataset[0]]
    if "file_name" not in entry:
        return gr.update(interactive=False)

    data_path = os.path.join(dataset_dir, entry["file_name"])
    return gr.update(interactive=os.path.isfile(data_path))


def get_preview(
    dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
) -> Tuple[int, list, Dict[str, Any]]:
    """Load samples from the first selected dataset for previewing.

    Returns (total_sample_count, samples[start:end], gr.update making the
    preview panel visible). The file format is chosen by extension:
    .json (one array), .jsonl (one object per line), anything else raw lines.
    """
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)

    data_file: str = dataset_info[dataset[0]]["file_name"]
    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
        if data_file.endswith(".json"):
            samples = json.load(f)
        elif data_file.endswith(".jsonl"):
            samples = [json.loads(row) for row in f]
        else:
            samples = list(f)
    return len(samples), samples[start:end], gr.update(visible=True)


def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    """Quantization bits are only selectable when fine-tuning with LoRA;
    otherwise force the dropdown back to "None" and lock it.
    """
    if finetuning_type == "lora":
        return gr.update(interactive=True)
    return gr.update(value="None", interactive=False)


def gen_cmd(args: Dict[str, Any]) -> str:
    """Render the training arguments as a markdown-fenced shell command.

    Training runs additionally get plot_loss=True (note: this mutates the
    caller's dict, matching existing behavior). None and empty-string values
    are omitted; each remaining flag goes on its own backslash-continued line.
    """
    if args.get("do_train", None):
        args["plot_loss"] = True
    lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
    lines.extend(
        "    --{} {} ".format(key, str(value))
        for key, value in args.items()
        if value is not None and value != ""
    )
    return "```bash\n{}\n```".format("\\\n".join(lines))


def get_eval_results(path: os.PathLike) -> str:
    """Read a JSON evaluation-result file and pretty-print it inside a
    markdown ```json fence for display in the web UI.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = json.load(f)
    pretty = json.dumps(content, indent=4)
    return "```json\n{}\n```\n".format(pretty)


def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
    """Plot the training-loss curve from trainer_log.jsonl.

    Returns a matplotlib Figure with the raw and smoothed loss curves, or
    None when the log file is missing or contains no loss entries.
    """
    log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
    if not os.path.isfile(log_file):
        return None

    plt.close("all")
    figure = plt.figure()
    axis = figure.add_subplot(111)
    steps, losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            entry = json.loads(raw_line)
            # Truthiness check: entries without a loss (or loss == 0) are skipped,
            # matching the original behavior.
            if entry.get("loss", None):
                steps.append(entry["current_steps"])
                losses.append(entry["loss"])

    if not losses:
        return None

    axis.plot(steps, losses, alpha=0.4, label="original")
    axis.plot(steps, smooth(losses), label="smoothed")
    axis.legend()
    axis.set_xlabel("step")
    axis.set_ylabel("loss")
    return figure


def save_model(
    lang: str,
    model_name: str,
    checkpoints: List[str],
    finetuning_type: str,
    template: str,
    max_shard_size: int,
    save_dir: str
) -> Generator[str, None, None]:
    """Export the fine-tuned model to save_dir, yielding localized status
    messages for the UI.

    Validates inputs in order (model name, model path, checkpoints, save
    directory) and yields the matching error alert then stops on the first
    failure; otherwise yields exporting/exported progress messages around
    the export_model call.
    """
    if not model_name:
        yield ALERTS["err_no_model"][lang]
        return

    model_path = get_model_path(model_name)
    if not model_path:
        yield ALERTS["err_no_path"][lang]
        return

    if not checkpoints:
        yield ALERTS["err_no_checkpoint"][lang]
        return

    # Multiple checkpoints are passed as a single comma-separated string.
    checkpoint_dir = ",".join(
        get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints
    )

    if not save_dir:
        yield ALERTS["err_no_save_dir"][lang]
        return

    export_args = {
        "model_name_or_path": model_path,
        "checkpoint_dir": checkpoint_dir,
        "finetuning_type": finetuning_type,
        "template": template,
        "output_dir": save_dir,
    }

    yield ALERTS["info_exporting"][lang]
    export_model(export_args, max_shard_size="{}GB".format(max_shard_size))
    yield ALERTS["info_exported"][lang]