# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Generator
from typing import TYPE_CHECKING, Any

from .chatter import WebChatModel
from .common import create_ds_config, get_time, load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner


if TYPE_CHECKING:
    from gradio.components import Component


class Engine:
    r"""A general engine to control the behaviors of Web UI."""

    def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
        r"""Create the manager, runner and chat model that back the Web UI.

        Args:
            demo_mode: forwarded to ``Runner`` and ``WebChatModel``. When True,
                ``create_ds_config()`` is not called here and ``resume`` does not
                read the saved user config.
            pure_chat: when True the chat model is built with ``lazy_init=False``
                and ``resume`` skips the train/eval initial values.
        """
        self.demo_mode = demo_mode
        self.pure_chat = pure_chat
        self.manager = Manager()
        self.runner = Runner(self.manager, demo_mode)
        # the chat model is only lazily initialized when the full (non pure-chat) UI is shown
        self.chatter = WebChatModel(self.manager, demo_mode, lazy_init=(not pure_chat))
        if not demo_mode:
            create_ds_config()

    def _update_component(self, input_dict: dict[str, dict[str, Any]]) -> dict["Component", "Component"]:
        r"""Update gradio components according to the (elem_id, properties) mapping.

        Args:
            input_dict: maps an element id registered in the manager (e.g. ``"top.lang"``)
                to the keyword arguments used to build its replacement component.

        Returns:
            A dict mapping each looked-up component to a new instance of the same
            class constructed from the given properties.
        """
        # "Component" is only imported under TYPE_CHECKING, so keep the annotation a string
        output_dict: dict["Component", "Component"] = {}
        for elem_id, elem_attr in input_dict.items():
            elem = self.manager.get_elem_by_id(elem_id)
            output_dict[elem] = elem.__class__(**elem_attr)

        return output_dict

    def resume(self) -> Generator[dict["Component", "Component"], None, None]:
        r"""Get the initial value of gradio components and restore training status if necessary.

        Yields:
            First, the initial component updates: the UI language (saved ``lang`` or
            ``"en"``), chat box visibility and, unless ``pure_chat``, the default
            train/eval fields derived from the current time plus the last used model.
            Then, only if the runner reports a running job and neither ``demo_mode``
            nor ``pure_chat`` is set, the runner's ``running_data`` values followed by
            a resume-button update.
        """
        user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
        lang = user_config.get("lang", None) or "en"
        init_dict: dict[str, dict[str, Any]] = {
            "top.lang": {"value": lang},
            "infer.chat_box": {"visible": self.chatter.loaded},
        }

        if not self.pure_chat:
            current_time = get_time()
            init_dict["train.current_time"] = {"value": current_time}
            init_dict["train.output_dir"] = {"value": f"train_{current_time}"}
            init_dict["train.config_path"] = {"value": f"{current_time}.yaml"}
            init_dict["eval.output_dir"] = {"value": f"eval_{current_time}"}
            init_dict["infer.mm_box"] = {"visible": False}

            if user_config.get("last_model", None):
                init_dict["top.model_name"] = {"value": user_config["last_model"]}

        yield self._update_component(init_dict)

        if self.runner.running and not self.demo_mode and not self.pure_chat:
            # re-apply the values stored in runner.running_data to their components
            yield {elem: elem.__class__(value=value) for elem, value in self.runner.running_data.items()}
            # set the train-tab resume button when runner.do_train, otherwise the eval-tab one
            resume_btn = "train.resume_btn" if self.runner.do_train else "eval.resume_btn"
            yield self._update_component({resume_btn: {"value": True}})

    def change_lang(self, lang: str) -> dict["Component", "Component"]:
        r"""Update the displayed language of gradio components.

        Args:
            lang: language key looked up inside each ``LOCALES`` entry.

        Returns:
            A dict mapping every registered component whose name appears in
            ``LOCALES`` to a new instance built from ``LOCALES[name][lang]``.

        Raises:
            KeyError: presumably, if a matching ``LOCALES`` entry has no ``lang`` key
                (assumes entries are plain dicts — confirm against ``locales``).
        """
        return {
            elem: elem.__class__(**LOCALES[elem_name][lang])
            for elem_name, elem in self.manager.get_elem_iter()
            if elem_name in LOCALES
        }