# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections.abc import Generator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional

from transformers.utils import is_torch_npu_available

from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir, load_config
from .locales import ALERTS


if TYPE_CHECKING:
    from ..chat import BaseEngine
    from .manager import Manager


if is_gradio_available():
    import gradio as gr


def _escape_html(text: str) -> str:
    r"""Escape HTML characters."""
    return text.replace("<", "&lt;").replace(">", "&gt;")


def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
    r"""Post-process the response text.

    Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
    """
    if thought_words[0] not in text:
        return _escape_html(text) if escape_html else text

    text = text.replace(thought_words[0], "")
    result = text.split(thought_words[1], maxsplit=1)
    if len(result) == 1:
        summary = ALERTS["info_thinking"][lang]
        thought, answer = text, ""
    else:
        summary = ALERTS["info_thought"][lang]
        thought, answer = result

    if escape_html:
        thought, answer = _escape_html(thought), _escape_html(answer)

    return (
        f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
    )


@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    r"""Temporarily set an attribute on an object, restoring the previous value on exit."""
    old_value = getattr(obj, name, None)
    setattr(obj, name, value)
    try:
        yield
    finally:  # restore even if the body raises
        setattr(obj, name, old_value)
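
# Usage sketch (object names illustrative):
#   with update_attr(engine.template, "enable_thinking", False):
#       ...  # enable_thinking is restored on exit, even on error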


class WebChatModel(ChatModel):
    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        self.engine: Optional[BaseEngine] = None

        if not lazy_init:  # read arguments from command line
            super().__init__()

        if demo_mode and os.getenv("DEMO_MODEL") and os.getenv("DEMO_TEMPLATE"):  # load demo model
            model_name_or_path = os.getenv("DEMO_MODEL")
            template = os.getenv("DEMO_TEMPLATE")
            infer_backend = os.getenv("DEMO_BACKEND", "huggingface")
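            # illustrative env values: DEMO_MODEL=Qwen/Qwen2.5-7B-Instruct, DEMO_TEMPLATE=qwen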
            super().__init__(
                dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
            )

    @property
    def loaded(self) -> bool:
        return self.engine is not None

    def load_model(self, data) -> Generator[str, None, None]:
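        # `data` maps Gradio components to their current values; look them up by element id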
        get = lambda elem_id: data[self.manager.get_elem_by_id(elem_id)]
        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
        user_config = load_config()

        error = ""
        if self.loaded:
            error = ALERTS["err_exists"][lang]
        elif not model_name:
            error = ALERTS["err_no_model"][lang]
        elif not model_path:
            error = ALERTS["err_no_path"][lang]
        elif self.demo_mode:
            error = ALERTS["err_demo"][lang]

        try:
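            # extra_args must be a valid JSON object, e.g. {"repetition_penalty": 1.1} (illustrative)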
            json.loads(get("infer.extra_args"))
        except json.JSONDecodeError:
            error = ALERTS["err_json_schema"][lang]

        if error:
            gr.Warning(error)
            yield error
            return

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_path,
            cache_dir=user_config.get("cache_dir", None),
            finetuning_type=finetuning_type,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
            infer_backend=get("infer.infer_backend"),
            infer_dtype=get("infer.infer_dtype"),
            trust_remote_code=True,
        )
        args.update(json.loads(get("infer.extra_args")))

        # checkpoints
        if checkpoint_path:
            if finetuning_type in PEFT_METHODS:  # list
                args["adapter_name_or_path"] = ",".join(
                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
                )
            else:  # str
                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

        # quantization
        if get("top.quantization_bit") != "none":
            args["quantization_bit"] = int(get("top.quantization_bit"))
            args["quantization_method"] = get("top.quantization_method")
            args["double_quantization"] = not is_torch_npu_available()

        super().__init__(args)
        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data) -> Generator[str, None, None]:
        lang = data[self.manager.get_elem_by_id("top.lang")]

        if self.demo_mode:
            gr.Warning(ALERTS["err_demo"][lang])
            yield ALERTS["err_demo"][lang]
            return

        yield ALERTS["info_unloading"][lang]
        self.engine = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    @staticmethod
    def append(
        chatbot: list[dict[str, str]],
        messages: list[dict[str, str]],
        role: str,
        query: str,
        escape_html: bool,
    ) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
        r"""Add the user input to chatbot.

        Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
        Output: infer.chatbot, infer.messages, infer.query
        """
        return (
            chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
            messages + [{"role": role, "content": query}],
            "",
        )

    def stream(
        self,
        chatbot: list[dict[str, str]],
        messages: list[dict[str, str]],
        lang: str,
        system: str,
        tools: str,
        image: Optional[Any],
        video: Optional[Any],
        audio: Optional[Any],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
        skip_special_tokens: bool,
        escape_html: bool,
        enable_thinking: bool,
    ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
        r"""Generate output text in stream.

        Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
        Output: infer.chatbot, infer.messages
        """
        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
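            # placeholder assistant turn; its content is rewritten as new tokens stream in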
            chatbot.append({"role": "assistant", "content": ""})
            response = ""
            for new_text in self.stream_chat(
                messages,
                system,
                tools,
                images=[image] if image else None,
                videos=[video] if video else None,
                audios=[audio] if audio else None,
                max_new_tokens=max_new_tokens,
                top_p=top_p,
                temperature=temperature,
                skip_special_tokens=skip_special_tokens,
            ):
                response += new_text
                if tools:
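                    # try to extract completed tool calls from the streamed text so far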
                    result = self.engine.template.extract_tool(response)
                else:
                    result = response

                if isinstance(result, list):
                    tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
                    tool_calls = json.dumps(tool_calls, ensure_ascii=False)
                    output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                    bot_text = "```json\n" + tool_calls + "\n```"
                else:
                    output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                    bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

                chatbot[-1] = {"role": "assistant", "content": bot_text}
                yield chatbot, output_messages