# infer.py (1.4 KB)
# Last committed by zhaoying1.
from typing import TYPE_CHECKING, Dict

import gradio as gr

from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box

if TYPE_CHECKING:
    from gradio.components import Component


def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
    """Build the inference tab and wire its load/unload buttons.

    Creates a row with a load and an unload button, a read-only info textbox,
    and a chat box backed by a ``WebChatModel``. The chat box visibility is
    refreshed after each load/unload based on whether ``chat_model.model`` is
    set.

    Args:
        top_elems: Shared components from elsewhere in the UI. The keys
            ``"lang"``, ``"model_name"``, ``"checkpoints"``,
            ``"finetuning_type"``, ``"quantization_bit"``, ``"template"`` and
            ``"system_prompt"`` are read here and passed as inputs to
            ``chat_model.load_model``; ``"lang"`` is also passed to
            ``chat_model.unload_model``.

    Returns:
        A dict with ``info_box``, ``load_btn`` and ``unload_btn`` plus every
        entry of the ``chat_elems`` mapping returned by ``create_chat_box``.
    """
    # NOTE: components are created in the same order as before; Gradio lays
    # them out in creation order, so do not reorder these statements.
    with gr.Row():
        load_btn = gr.Button()
        unload_btn = gr.Button()

    info_box = gr.Textbox(show_label=False, interactive=False)

    # NOTE(review): lazy_init=True presumably defers loading weights until
    # load_model is invoked -- confirm against WebChatModel.
    chat_model = WebChatModel(lazy_init=True)
    chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)

    def _chat_box_visibility():
        # Show the chat box only while a model is attached to chat_model.
        return gr.update(visible=(chat_model.model is not None))

    load_inputs = [
        top_elems["lang"],
        top_elems["model_name"],
        top_elems["checkpoints"],
        top_elems["finetuning_type"],
        top_elems["quantization_bit"],
        top_elems["template"],
        top_elems["system_prompt"],
    ]

    # Load: report status into info_box, then refresh chat box visibility.
    load_btn.click(
        chat_model.load_model, load_inputs, [info_box]
    ).then(
        _chat_box_visibility, outputs=[chat_box]
    )

    # Unload: report status, reset the chatbot display and the stored
    # history to empty lists, then refresh chat box visibility.
    unload_btn.click(
        chat_model.unload_model, [top_elems["lang"]], [info_box]
    ).then(
        lambda: ([], []), outputs=[chatbot, history]
    ).then(
        _chat_box_visibility, outputs=[chat_box]
    )

    return dict(
        info_box=info_box,
        load_btn=load_btn,
        unload_btn=unload_btn,
        **chat_elems
    )