chatbot.py 4.21 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

chenych's avatar
chenych committed
15
import json
chenych's avatar
chenych committed
16
17
18
19
from typing import TYPE_CHECKING, Dict, Tuple

from ...data import Role
from ...extras.packages import is_gradio_available
chenych's avatar
chenych committed
20
from ..locales import ALERTS
chenych's avatar
chenych committed
21
22
23
24
25
26
27
28
29
30
31
32


if is_gradio_available():
    import gradio as gr


if TYPE_CHECKING:
    from gradio.components import Component

    from ..engine import Engine


chenych's avatar
chenych committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def check_json_schema(text: str, lang: str) -> None:
    r"""
    Checks if the json schema is valid.

    Shows a Gradio warning, localized via ``ALERTS[...][lang]``, in two cases:
    ``err_json_schema`` when ``text`` is not valid JSON, or is not a list of JSON
    objects; ``err_tool_name`` when one of those objects lacks a ``"name"`` key.
    Blank text and falsy JSON values (e.g. ``[]``, ``null``) mean "no tools" and
    produce no warning.

    Args:
        text: Raw content of the tools textbox.
        lang: UI language key used to select the alert message.
    """
    # A cleared box is treated like an empty tool list rather than malformed JSON.
    if not text or not text.strip():
        return

    try:
        tools = json.loads(text)
    except (ValueError, RecursionError):  # json.JSONDecodeError subclasses ValueError
        gr.Warning(ALERTS["err_json_schema"][lang])
        return

    if not tools:
        return

    # Explicit checks (not ``assert``) so validation still happens under ``python -O``;
    # requiring dicts avoids ``"name" in tool`` silently doing a substring test on strings.
    if not isinstance(tools, list) or not all(isinstance(tool, dict) for tool in tools):
        gr.Warning(ALERTS["err_json_schema"][lang])
        return

    if any("name" not in tool for tool in tools):
        gr.Warning(ALERTS["err_tool_name"][lang])


chenych's avatar
chenych committed
50
51
52
def create_chat_box(
    engine: "Engine", visible: bool = False
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
    r"""
    Builds the chat panel of the web UI and wires its event handlers.

    The panel holds the chat transcript, a column with role/system/tools inputs,
    a multimodal column (image, video and audio tabs), the query box with a submit
    button, and a side column of generation settings plus a clear button.

    Args:
        engine: UI engine; ``engine.manager`` is used to look up the shared
            ``top.lang`` element and ``engine.chatter`` supplies the ``append``
            and ``stream`` callbacks.
        visible: Initial visibility of the enclosing column.

    Returns:
        A tuple ``(chatbot, messages, elems)`` where ``elems`` maps names to the
        other components created here (including the enclosing ``chat_box``).
    """
    # Fetched once and reused for both the tool-schema check and the stream inputs.
    lang = engine.manager.get_elem_by_id("top.lang")
    with gr.Column(visible=visible) as chat_box:
        chatbot = gr.Chatbot(type="messages", show_copy_button=True)
        messages = gr.State([])
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Row():
                    with gr.Column():
                        role = gr.Dropdown(choices=[Role.USER.value, Role.OBSERVATION.value], value=Role.USER.value)
                        system = gr.Textbox(show_label=False)
                        tools = gr.Textbox(show_label=False, lines=3)

                    with gr.Column() as mm_box:
                        with gr.Tab("Image"):
                            image = gr.Image(type="pil")

                        with gr.Tab("Video"):
                            video = gr.Video()

                        with gr.Tab("Audio"):
                            audio = gr.Audio(type="filepath")

                query = gr.Textbox(show_label=False, lines=8)
                submit_btn = gr.Button(variant="primary")

            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
                top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
                skip_special_tokens = gr.Checkbox(value=True)
                escape_html = gr.Checkbox(value=True)
                clear_btn = gr.Button()

    # Validate the tool JSON on user edits; warnings are localized with the selected language.
    tools.input(check_json_schema, inputs=[tools, lang])

    # Two-step submit: ``append`` updates the transcript, message state and query box,
    # then (after it finishes) ``stream`` updates the transcript and message state again.
    submit_btn.click(
        engine.chatter.append,
        [chatbot, messages, role, query, escape_html],
        [chatbot, messages, query],
    ).then(
        engine.chatter.stream,
        [
            chatbot,
            messages,
            lang,
            system,
            tools,
            image,
            video,
            audio,
            max_new_tokens,
            top_p,
            temperature,
            skip_special_tokens,
            escape_html,
        ],
        [chatbot, messages],
    )
    # Reset both the visible transcript and the message state to empty lists.
    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

    return (
        chatbot,
        messages,
        dict(
            chat_box=chat_box,
            role=role,
            system=system,
            tools=tools,
            mm_box=mm_box,
            image=image,
            video=video,
            audio=audio,
            query=query,
            submit_btn=submit_btn,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            temperature=temperature,
            skip_special_tokens=skip_special_tokens,
            escape_html=escape_html,
            clear_btn=clear_btn,
        ),
    )