# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple

from transformers.utils import is_torch_npu_available
from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir, load_config
from .locales import ALERTS


if TYPE_CHECKING:
    from ..chat import BaseEngine
    from .manager import Manager


if is_gradio_available():
    import gradio as gr


def _escape_html(text: str) -> str:
    r"""
    Escapes HTML characters.
    """
    return text.replace("<", "&lt;").replace(">", "&gt;")


def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
    r"""
    Post-processes the response text.

    Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
    """
    if thought_words[0] not in text:
        return _escape_html(text) if escape_html else text

    text = text.replace(thought_words[0], "")
    result = text.split(thought_words[1], maxsplit=1)
    if len(result) == 1:
        summary = ALERTS["info_thinking"][lang]
        thought, answer = text, ""
    else:
        summary = ALERTS["info_thought"][lang]
        thought, answer = result

    if escape_html:
        thought, answer = _escape_html(thought), _escape_html(answer)

    return (
        f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
    )


class WebChatModel(ChatModel):
    r"""
    A chat model wrapper for the web UI with lazy loading and demo-mode support.
    """

    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        self.engine: Optional["BaseEngine"] = None

        # eager init: read arguments from the command line
        if not lazy_init:
            super().__init__()

        # demo mode: load the model configured via environment variables
        demo_model = os.environ.get("DEMO_MODEL")
        demo_template = os.environ.get("DEMO_TEMPLATE")
        if demo_mode and demo_model and demo_template:
            demo_backend = os.environ.get("DEMO_BACKEND", "huggingface")
            super().__init__(
                dict(model_name_or_path=demo_model, template=demo_template, infer_backend=demo_backend)
            )

    @property
    def loaded(self) -> bool:
        # the engine is only assigned after a successful load
        return self.engine is not None

    def load_model(self, data) -> Generator[str, None, None]:
        r"""
        Loads the model according to the web UI inputs, yielding status messages.
        """

        def get(elem_id: str):
            # look up the current value of a UI component by its element id
            return data[self.manager.get_elem_by_id(elem_id)]

        lang = get("top.lang")
        model_name = get("top.model_name")
        model_path = get("top.model_path")
        finetuning_type = get("top.finetuning_type")
        checkpoint_path = get("top.checkpoint_path")
        user_config = load_config()

        # validate inputs, reporting the first failed check (order matters)
        checks = (
            (self.loaded, "err_exists"),
            (not model_name, "err_no_model"),
            (not model_path, "err_no_path"),
            (self.demo_mode, "err_demo"),
        )
        error = next((ALERTS[alert_key][lang] for failed, alert_key in checks if failed), "")
        if error:
            gr.Warning(error)
            yield error
            return

        yield ALERTS["info_loading"][lang]
        rope_scaling = get("top.rope_scaling")
        booster = get("top.booster")
        args = {
            "model_name_or_path": model_path,
            "cache_dir": user_config.get("cache_dir", None),
            "finetuning_type": finetuning_type,
            "template": get("top.template"),
            "rope_scaling": rope_scaling if rope_scaling != "none" else None,
            "flash_attn": "fa2" if booster == "flashattn2" else "auto",
            "use_unsloth": booster == "unsloth",
            "enable_liger_kernel": booster == "liger_kernel",
            "infer_backend": get("infer.infer_backend"),
            "infer_dtype": get("infer.infer_dtype"),
            "trust_remote_code": True,
        }

        # checkpoints: PEFT methods take a list of adapters, others a single path
        if checkpoint_path:
            if finetuning_type in PEFT_METHODS:
                adapter_dirs = [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoint_path]
                args["adapter_name_or_path"] = ",".join(adapter_dirs)
            else:
                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

        # quantization (double quantization is unsupported on NPU devices)
        quantization_bit = get("top.quantization_bit")
        if quantization_bit != "none":
            args["quantization_bit"] = int(quantization_bit)
            args["quantization_method"] = get("top.quantization_method")
            args["double_quantization"] = not is_torch_npu_available()

        super().__init__(args)
        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data) -> Generator[str, None, None]:
        r"""
        Releases the loaded model and reclaims memory, yielding status messages.
        """
        lang = data[self.manager.get_elem_by_id("top.lang")]

        if self.demo_mode:  # unloading is not allowed in demo mode
            gr.Warning(ALERTS["err_demo"][lang])
            yield ALERTS["err_demo"][lang]
            return

        yield ALERTS["info_unloading"][lang]
        self.engine = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    @staticmethod
    def append(
        chatbot: List[Dict[str, str]],
        messages: List[Dict[str, str]],
        role: str,
        query: str,
        escape_html: bool,
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
        r"""
        Adds the user input to chatbot.

        Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
        Output: infer.chatbot, infer.messages, infer.query
        """
        display_text = _escape_html(query) if escape_html else query
        new_chatbot = chatbot + [{"role": "user", "content": display_text}]
        new_messages = messages + [{"role": role, "content": query}]
        return new_chatbot, new_messages, ""  # empty string clears the input box

    def stream(
        self,
        chatbot: List[Dict[str, str]],
        messages: List[Dict[str, str]],
        lang: str,
        system: str,
        tools: str,
        image: Optional[Any],
        video: Optional[Any],
        audio: Optional[Any],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
        skip_special_tokens: bool,
        escape_html: bool,
    ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
        r"""
        Generates output text in stream.

        Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
        Output: infer.chatbot, infer.messages
        """
        chatbot.append({"role": "assistant", "content": ""})
        response = ""
        for new_text in self.stream_chat(
            messages,
            system,
            tools,
            images=[image] if image else None,
            videos=[video] if video else None,
            audios=[audio] if audio else None,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            temperature=temperature,
            skip_special_tokens=skip_special_tokens,
        ):
            response += new_text
            # try to extract tool calls only when a tool schema was supplied
            result = self.engine.template.extract_tool(response) if tools else response

            if isinstance(result, list):  # parsed tool calls
                tool_calls = json.dumps(
                    [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result],
                    ensure_ascii=False,
                )
                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                bot_text = "```json\n" + tool_calls + "\n```"
            else:  # plain text response
                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

            chatbot[-1] = {"role": "assistant", "content": bot_text}
            yield chatbot, output_messages