runner.py 20.3 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

luopl's avatar
luopl committed
15
import json
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
16
import os
chenych's avatar
chenych committed
17
18
19
from copy import deepcopy
from subprocess import Popen, TimeoutExpired
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
20
21

from transformers.trainer import TRAINING_ARGS_NAME
luopl's avatar
luopl committed
22
from transformers.utils import is_torch_npu_available
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
23

chenych's avatar
chenych committed
24
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
luopl's avatar
luopl committed
25
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
luopl's avatar
luopl committed
26
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
chenych's avatar
chenych committed
27
28
29
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
from .locales import ALERTS, LOCALES
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component

    from .manager import Manager


class Runner:
    """Validate LlamaBoard form state and run `llamafactory-cli train` jobs from it.

    Form values arrive as a mapping from Gradio components to values. They are
    validated, translated into CLI arguments, and the launched subprocess is
    polled until it exits so progress can be streamed back to the UI.
    """

    def __init__(self, manager: "Manager", demo_mode: bool = False) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        # Resume
        self.trainer: Optional["Popen"] = None  # subprocess started by `_launch`, None when idle
        self.do_train = True  # whether the launched job is a training (True) or evaluation (False) job
        self.running_data: Optional[Dict["Component", Any]] = None  # form snapshot of the launched job
        # State
        self.aborted = False
        self.running = False

    def set_abort(self) -> None:
        """Flag the current job as aborted and pass its process id to `abort_process`."""
        self.aborted = True
        if self.trainer is not None:
            abort_process(self.trainer.pid)

    @staticmethod
    def _load_extra_args(extra_args: Any) -> Optional[Dict[str, Any]]:
        """Decode the user-supplied extra-arguments JSON.

        Returns:
            The decoded mapping, or None when the text is not valid JSON or does
            not decode to a JSON object (it is later merged via ``dict.update``).
        """
        try:
            parsed = json.loads(extra_args)
        except (json.JSONDecodeError, TypeError):
            return None

        return parsed if isinstance(parsed, dict) else None

    def _initialize(self, data: Dict["Component", Any], do_train: bool, from_preview: bool) -> str:
        """Validate the form before previewing or launching a job.

        Args:
            data: Mapping from UI components to their current values.
            do_train: Validate the training tab if True, otherwise the evaluation tab.
            from_preview: True when only the command is previewed; the demo-mode
                block and the missing-accelerator warning apply only to launches.

        Returns:
            A localized error message, or an empty string when the form is valid.
        """
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        dataset = get("train.dataset") if do_train else get("eval.dataset")

        if self.running:
            return ALERTS["err_conflict"][lang]

        if not model_name:
            return ALERTS["err_no_model"][lang]

        if not model_path:
            return ALERTS["err_no_path"][lang]

        if not dataset:
            return ALERTS["err_no_dataset"][lang]

        if not from_preview and self.demo_mode:
            return ALERTS["err_demo"][lang]

        if do_train:
            if not get("train.output_dir"):
                return ALERTS["err_no_output_dir"][lang]

            # extra args are merged into the CLI args with dict.update, so they must be a JSON object
            if self._load_extra_args(get("train.extra_args")) is None:
                return ALERTS["err_json_schema"][lang]

            stage = TRAINING_STAGES[get("train.training_stage")]
            if stage == "ppo" and not get("train.reward_model"):
                return ALERTS["err_no_reward_model"][lang]
        else:
            if not get("eval.output_dir"):
                return ALERTS["err_no_output_dir"][lang]

        if not from_preview and not is_gpu_or_npu_available():
            gr.Warning(ALERTS["warn_no_cuda"][lang])

        return ""

    def _finalize(self, lang: str, finish_info: str) -> str:
        """Show the final status, reset the job state and return the message to display."""
        finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
        gr.Info(finish_info)
        self.trainer = None
        self.aborted = False
        self.running = False
        self.running_data = None
        torch_gc()
        return finish_info

    def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
        """Translate the training-tab form into keyword arguments for `llamafactory-cli train`."""
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
        user_config = load_config()

        args = dict(
            stage=TRAINING_STAGES[get("train.training_stage")],
            do_train=True,
            model_name_or_path=get("top.model_path"),
            cache_dir=user_config.get("cache_dir", None),
            preprocessing_num_workers=16,
            finetuning_type=finetuning_type,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
            dataset_dir=get("train.dataset_dir"),
            dataset=",".join(get("train.dataset")),
            cutoff_len=get("train.cutoff_len"),
            learning_rate=float(get("train.learning_rate")),
            num_train_epochs=float(get("train.num_train_epochs")),
            max_samples=int(get("train.max_samples")),
            per_device_train_batch_size=get("train.batch_size"),
            gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
            lr_scheduler_type=get("train.lr_scheduler_type"),
            max_grad_norm=float(get("train.max_grad_norm")),
            logging_steps=get("train.logging_steps"),
            save_steps=get("train.save_steps"),
            warmup_steps=get("train.warmup_steps"),
            neftune_noise_alpha=get("train.neftune_alpha") or None,
            packing=get("train.packing") or get("train.neat_packing"),
            neat_packing=get("train.neat_packing"),
            train_on_prompt=get("train.train_on_prompt"),
            mask_history=get("train.mask_history"),
            resize_vocab=get("train.resize_vocab"),
            use_llama_pro=get("train.use_llama_pro"),
            shift_attn=get("train.shift_attn"),
            report_to="all" if get("train.report_to") else "none",
            use_galore=get("train.use_galore"),
            use_apollo=get("train.use_apollo"),
            use_badam=get("train.use_badam"),
            use_swanlab=get("train.use_swanlab"),
            output_dir=get_save_dir(model_name, finetuning_type, get("train.output_dir")),
            fp16=(get("train.compute_type") == "fp16"),
            bf16=(get("train.compute_type") == "bf16"),
            pure_bf16=(get("train.compute_type") == "pure_bf16"),
            plot_loss=True,
            trust_remote_code=True,
            ddp_timeout=180000000,
            include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True,  # FIXME
        )
        # user-supplied overrides; `_initialize` has already checked that this is a JSON object
        args.update(json.loads(get("train.extra_args")))

        # checkpoints
        if get("top.checkpoint_path"):
            if finetuning_type in PEFT_METHODS:  # list
                args["adapter_name_or_path"] = ",".join(
                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
                )
            else:  # str
                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

        # quantization
        if get("top.quantization_bit") in QUANTIZATION_BITS:
            args["quantization_bit"] = int(get("top.quantization_bit"))
            args["quantization_method"] = get("top.quantization_method")
            args["double_quantization"] = not is_torch_npu_available()

        # freeze config
        if args["finetuning_type"] == "freeze":
            args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
            args["freeze_trainable_modules"] = get("train.freeze_trainable_modules")
            args["freeze_extra_modules"] = get("train.freeze_extra_modules") or None

        # lora config
        if args["finetuning_type"] == "lora":
            args["lora_rank"] = get("train.lora_rank")
            args["lora_alpha"] = get("train.lora_alpha")
            args["lora_dropout"] = get("train.lora_dropout")
            args["loraplus_lr_ratio"] = get("train.loraplus_lr_ratio") or None
            args["create_new_adapter"] = get("train.create_new_adapter")
            args["use_rslora"] = get("train.use_rslora")
            args["use_dora"] = get("train.use_dora")
            args["pissa_init"] = get("train.use_pissa")
            args["pissa_convert"] = get("train.use_pissa")
            args["lora_target"] = get("train.lora_target") or "all"
            args["additional_target"] = get("train.additional_target") or None

            if args["use_llama_pro"]:
                args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")

        # rlhf config
        if args["stage"] == "ppo":
            if finetuning_type in PEFT_METHODS:
                args["reward_model"] = ",".join(
                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("train.reward_model")]
                )
            else:
                args["reward_model"] = get_save_dir(model_name, finetuning_type, get("train.reward_model"))

            args["reward_model_type"] = "lora" if finetuning_type == "lora" else "full"
            args["ppo_score_norm"] = get("train.ppo_score_norm")
            args["ppo_whiten_rewards"] = get("train.ppo_whiten_rewards")
            args["top_k"] = 0
            args["top_p"] = 0.9
        elif args["stage"] in ["dpo", "kto"]:
            args["pref_beta"] = get("train.pref_beta")
            args["pref_ftx"] = get("train.pref_ftx")
            args["pref_loss"] = get("train.pref_loss")

        # galore config
        if args["use_galore"]:
            args["galore_rank"] = get("train.galore_rank")
            args["galore_update_interval"] = get("train.galore_update_interval")
            args["galore_scale"] = get("train.galore_scale")
            args["galore_target"] = get("train.galore_target")

        # apollo config
        if args["use_apollo"]:
            args["apollo_rank"] = get("train.apollo_rank")
            args["apollo_update_interval"] = get("train.apollo_update_interval")
            args["apollo_scale"] = get("train.apollo_scale")
            args["apollo_target"] = get("train.apollo_target")

        # badam config
        if args["use_badam"]:
            args["badam_mode"] = get("train.badam_mode")
            args["badam_switch_mode"] = get("train.badam_switch_mode")
            args["badam_switch_interval"] = get("train.badam_switch_interval")
            args["badam_update_ratio"] = get("train.badam_update_ratio")

        # swanlab config
        if get("train.use_swanlab"):
            args["swanlab_project"] = get("train.swanlab_project")
            args["swanlab_run_name"] = get("train.swanlab_run_name")
            args["swanlab_workspace"] = get("train.swanlab_workspace")
            args["swanlab_api_key"] = get("train.swanlab_api_key")
            args["swanlab_mode"] = get("train.swanlab_mode")

        # eval config (not applied to PPO)
        if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
            args["val_size"] = get("train.val_size")
            args["eval_strategy"] = "steps"
            args["eval_steps"] = args["save_steps"]
            args["per_device_eval_batch_size"] = args["per_device_train_batch_size"]

        # ds config
        if get("train.ds_stage") != "none":
            ds_stage = get("train.ds_stage")
            ds_offload = "offload_" if get("train.ds_offload") else ""
            args["deepspeed"] = os.path.join(DEFAULT_CACHE_DIR, f"ds_z{ds_stage}_{ds_offload}config.json")

        return args

    def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
        """Translate the evaluation-tab form into keyword arguments for `llamafactory-cli train`."""
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
        model_name, finetuning_type = get("top.model_name"), get("top.finetuning_type")
        user_config = load_config()

        args = dict(
            stage="sft",
            model_name_or_path=get("top.model_path"),
            cache_dir=user_config.get("cache_dir", None),
            preprocessing_num_workers=16,
            finetuning_type=finetuning_type,
            quantization_method=get("top.quantization_method"),
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            dataset_dir=get("eval.dataset_dir"),
            eval_dataset=",".join(get("eval.dataset")),
            cutoff_len=get("eval.cutoff_len"),
            max_samples=int(get("eval.max_samples")),
            per_device_eval_batch_size=get("eval.batch_size"),
            predict_with_generate=True,
            max_new_tokens=get("eval.max_new_tokens"),
            top_p=get("eval.top_p"),
            temperature=get("eval.temperature"),
            output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")),
            trust_remote_code=True,
        )

        if get("eval.predict"):
            args["do_predict"] = True
        else:
            args["do_eval"] = True

        # checkpoints
        if get("top.checkpoint_path"):
            if finetuning_type in PEFT_METHODS:  # list
                args["adapter_name_or_path"] = ",".join(
                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in get("top.checkpoint_path")]
                )
            else:  # str
                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, get("top.checkpoint_path"))

        # quantization
        if get("top.quantization_bit") in QUANTIZATION_BITS:
            args["quantization_bit"] = int(get("top.quantization_bit"))
            args["quantization_method"] = get("top.quantization_method")

        return args

    def _preview(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", str], None, None]:
        """Yield either the validation error or the generated command for the output box."""
        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
        error = self._initialize(data, do_train, from_preview=True)
        if error:
            gr.Warning(error)
            yield {output_box: error}
        else:
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
            yield {output_box: gen_cmd(args)}

    def _launch(self, data: Dict["Component", Any], do_train: bool) -> Generator[Dict["Component", Any], None, None]:
        """Validate the form, start `llamafactory-cli train` in a subprocess and stream its progress."""
        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if do_train else "eval"))
        error = self._initialize(data, do_train, from_preview=False)
        if error:
            gr.Warning(error)
            yield {output_box: error}
        else:
            self.do_train, self.running_data = do_train, data
            args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)

            os.makedirs(args["output_dir"], exist_ok=True)
            save_args(os.path.join(args["output_dir"], LLAMABOARD_CONFIG), self._form_config_dict(data))

            env = deepcopy(os.environ)
            env["LLAMABOARD_ENABLED"] = "1"
            env["LLAMABOARD_WORKDIR"] = args["output_dir"]
            if args.get("deepspeed", None) is not None:
                env["FORCE_TORCHRUN"] = "1"

            self.trainer = Popen(["llamafactory-cli", "train", save_cmd(args)], env=env)
            yield from self.monitor()

    def _form_config_dict(self, data: Dict["Component", Any]) -> Dict[str, Any]:
        """Map form values to their element ids, omitting the ids listed in `skip_ids`."""
        skip_ids = {"top.lang", "top.model_path", "train.output_dir", "train.config_path"}
        config_dict = {}
        for elem, value in data.items():
            elem_id = self.manager.get_id_by_elem(elem)
            if elem_id not in skip_ids:
                config_dict[elem_id] = value

        return config_dict

    def preview_train(self, data):
        yield from self._preview(data, do_train=True)

    def preview_eval(self, data):
        yield from self._preview(data, do_train=False)

    def run_train(self, data):
        yield from self._launch(data, do_train=True)

    def run_eval(self, data):
        yield from self._launch(data, do_train=False)

    def monitor(self):
        """Poll the launched process every 2 seconds and stream log/progress/loss updates.

        After the process exits, yields one final update built from `_finalize`,
        which also resets the runner state.
        """
        self.aborted = False
        self.running = True

        get = lambda elem_id: self.running_data[self.manager.get_elem_by_id(elem_id)]
        lang, model_name, finetuning_type = get("top.lang"), get("top.model_name"), get("top.finetuning_type")
        output_dir = get("{}.output_dir".format("train" if self.do_train else "eval"))
        output_path = get_save_dir(model_name, finetuning_type, output_dir)

        output_box = self.manager.get_elem_by_id("{}.output_box".format("train" if self.do_train else "eval"))
        progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
        loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None

        running_log = ""
        while self.trainer is not None:
            if self.aborted:
                yield {
                    output_box: ALERTS["info_aborting"][lang],
                    progress_bar: gr.Slider(visible=False),
                }
            else:
                running_log, running_progress, running_loss = get_trainer_info(output_path, self.do_train)
                return_dict = {
                    output_box: running_log,
                    progress_bar: running_progress,
                }
                if running_loss is not None:
                    return_dict[loss_viewer] = running_loss

                yield return_dict

            try:
                self.trainer.wait(2)
                self.trainer = None
            except TimeoutExpired:
                continue

        if self.do_train:
            if os.path.exists(os.path.join(output_path, TRAINING_ARGS_NAME)) or use_ray():
                finish_info = ALERTS["info_finished"][lang]
            else:
                finish_info = ALERTS["err_failed"][lang]
        else:
            eval_results_path = os.path.join(output_path, "all_results.json")
            if os.path.exists(eval_results_path):
                finish_info = get_eval_results(eval_results_path)
            elif use_ray():
                # the results file is absent here, so there is nothing to read and render
                finish_info = ALERTS["info_finished"][lang]
            else:
                finish_info = ALERTS["err_failed"][lang]

        return_dict = {
            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
            progress_bar: gr.Slider(visible=False),
        }
        yield return_dict

    def save_args(self, data):
        """Validate the training form and save it under DEFAULT_CONFIG_DIR as the chosen config file."""
        output_box = self.manager.get_elem_by_id("train.output_box")
        error = self._initialize(data, do_train=True, from_preview=True)
        if error:
            gr.Warning(error)
            return {output_box: error}

        lang = data[self.manager.get_elem_by_id("top.lang")]
        config_path = data[self.manager.get_elem_by_id("train.config_path")]
        os.makedirs(DEFAULT_CONFIG_DIR, exist_ok=True)
        save_path = os.path.join(DEFAULT_CONFIG_DIR, config_path)

        save_args(save_path, self._form_config_dict(data))
        return {output_box: ALERTS["info_config_saved"][lang] + save_path}

    def load_args(self, lang: str, config_path: str):
        """Load a saved config from DEFAULT_CONFIG_DIR and map its values back onto UI components."""
        output_box = self.manager.get_elem_by_id("train.output_box")
        config_dict = load_args(os.path.join(DEFAULT_CONFIG_DIR, config_path))
        if config_dict is None:
            gr.Warning(ALERTS["err_config_not_found"][lang])
            return {output_box: ALERTS["err_config_not_found"][lang]}

        output_dict: Dict["Component", Any] = {output_box: ALERTS["info_config_loaded"][lang]}
        for elem_id, value in config_dict.items():
            output_dict[self.manager.get_elem_by_id(elem_id)] = value

        return output_dict

    def check_output_dir(self, lang: str, model_name: str, finetuning_type: str, output_dir: str):
        """Warn when the output dir already exists and restore its saved LlamaBoard config, if present."""
        output_box = self.manager.get_elem_by_id("train.output_box")
        output_dict: Dict["Component", Any] = {output_box: LOCALES["output_box"][lang]["value"]}
        if model_name and output_dir and os.path.isdir(get_save_dir(model_name, finetuning_type, output_dir)):
            gr.Warning(ALERTS["warn_output_dir_exists"][lang])
            output_dict[output_box] = ALERTS["warn_output_dir_exists"][lang]

            output_dir = get_save_dir(model_name, finetuning_type, output_dir)
            config_dict = load_args(os.path.join(output_dir, LLAMABOARD_CONFIG))  # load llamaboard config
            # `load_args` returns None when the dir holds no (readable) LlamaBoard config
            if config_dict is not None:
                for elem_id, value in config_dict.items():
                    output_dict[self.manager.get_elem_by_id(elem_id)] = value

        return output_dict