# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Optional

from transformers.utils import is_torch_npu_available

from ..chat import ChatModel
from ..data import Role
from ..extras.constants import PEFT_METHODS
from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir, load_config
from .locales import ALERTS


if TYPE_CHECKING:
    from ..chat import BaseEngine
    from .manager import Manager


if is_gradio_available():
    import gradio as gr


def _escape_html(text: str) -> str:
chenych's avatar
chenych committed
41
    r"""Escape HTML characters."""
chenych's avatar
chenych committed
42
43
44
    return text.replace("<", "&lt;").replace(">", "&gt;")


def _format_response(text: str, lang: str, escape_html: bool, thought_words: tuple[str, str]) -> str:
    r"""Post-process the response text.

    Wraps the model's chain-of-thought (delimited by ``thought_words``) in a
    collapsible ``<details>`` element, followed by the final answer.

    Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
    """
    open_tag, close_tag = thought_words
    if open_tag not in text:
        # no thinking section at all: return the (optionally escaped) text verbatim
        return _escape_html(text) if escape_html else text

    remainder = text.replace(open_tag, "")
    pieces = remainder.split(close_tag, maxsplit=1)
    if len(pieces) == 2:
        # the closing tag arrived: thought is complete, answer follows it
        summary = ALERTS["info_thought"][lang]
        thought, answer = pieces
    else:
        # still thinking: everything seen so far is thought, no answer yet
        summary = ALERTS["info_thinking"][lang]
        thought, answer = remainder, ""

    if escape_html:
        thought = _escape_html(thought)
        answer = _escape_html(answer)

    return (
        f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
        f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
    )


class WebChatModel(ChatModel):
    r"""A chat model driven by the Gradio Web UI.

    Wraps ``ChatModel`` so the engine can be created lazily from the values of
    the UI components held by ``manager``, instead of from command-line args.
    """

    def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
        self.manager = manager
        self.demo_mode = demo_mode
        # engine is None until load_model() (or eager init below) succeeds
        self.engine: Optional[BaseEngine] = None

        if not lazy_init:  # read arguments from command line
            super().__init__()

        if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
            model_name_or_path = os.environ.get("DEMO_MODEL")
            template = os.environ.get("DEMO_TEMPLATE")
            infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
            super().__init__(
                dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
            )

    @property
    def loaded(self) -> bool:
        r"""Whether an inference engine is currently initialized."""
        return self.engine is not None

    def load_model(self, data) -> Generator[str, None, None]:
        r"""Build the inference engine from the Web UI component values.

        Yields localized status messages (loading / loaded / error) for display.
        """

        # fix: bind a def instead of a lambda (PEP 8 E731); reads one component value by element id
        def get(elem_id: str):
            return data[self.manager.get_elem_by_id(elem_id)]

        lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
        finetuning_type, checkpoint_path = get("top.finetuning_type"), get("top.checkpoint_path")
        user_config = load_config()

        # validate the request before doing any heavy work
        error = ""
        if self.loaded:
            error = ALERTS["err_exists"][lang]
        elif not model_name:
            error = ALERTS["err_no_model"][lang]
        elif not model_path:
            error = ALERTS["err_no_path"][lang]
        elif self.demo_mode:
            error = ALERTS["err_demo"][lang]

        if error:
            gr.Warning(error)
            yield error
            return

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_path,
            cache_dir=user_config.get("cache_dir", None),
            finetuning_type=finetuning_type,
            template=get("top.template"),
            rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") != "none" else None,
            flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
            use_unsloth=(get("top.booster") == "unsloth"),
            enable_liger_kernel=(get("top.booster") == "liger_kernel"),
            infer_backend=get("infer.infer_backend"),
            infer_dtype=get("infer.infer_dtype"),
            vllm_enforce_eager=True,
            trust_remote_code=True,
        )

        # checkpoints
        if checkpoint_path:
            if finetuning_type in PEFT_METHODS:  # list: adapters are joined into a comma-separated path list
                args["adapter_name_or_path"] = ",".join(
                    [get_save_dir(model_name, finetuning_type, adapter) for adapter in checkpoint_path]
                )
            else:  # str: full checkpoint replaces the base model path
                args["model_name_or_path"] = get_save_dir(model_name, finetuning_type, checkpoint_path)

        # quantization
        if get("top.quantization_bit") != "none":
            args["quantization_bit"] = int(get("top.quantization_bit"))
            args["quantization_method"] = get("top.quantization_method")
            # NOTE(review): double quantization appears unsupported on NPU — confirm
            args["double_quantization"] = not is_torch_npu_available()

        super().__init__(args)
        yield ALERTS["info_loaded"][lang]

    def unload_model(self, data) -> Generator[str, None, None]:
        r"""Tear down the engine and free accelerator memory, yielding status messages."""
        lang = data[self.manager.get_elem_by_id("top.lang")]

        if self.demo_mode:  # unloading is disabled in demo mode
            gr.Warning(ALERTS["err_demo"][lang])
            yield ALERTS["err_demo"][lang]
            return

        yield ALERTS["info_unloading"][lang]
        self.engine = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    @staticmethod
    def append(
        chatbot: list[dict[str, str]],
        messages: list[dict[str, str]],
        role: str,
        query: str,
        escape_html: bool,
    ) -> tuple[list[dict[str, str]], list[dict[str, str]], str]:
        r"""Add the user input to chatbot.

        Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
        Output: infer.chatbot, infer.messages, infer.query
        """
        # chatbot always displays the input as "user"; messages records the selected role
        # (the trailing "" clears the query textbox)
        return (
            chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
            messages + [{"role": role, "content": query}],
            "",
        )

    def stream(
        self,
        chatbot: list[dict[str, str]],
        messages: list[dict[str, str]],
        lang: str,
        system: str,
        tools: str,
        image: Optional[Any],
        video: Optional[Any],
        audio: Optional[Any],
        max_new_tokens: int,
        top_p: float,
        temperature: float,
        skip_special_tokens: bool,
        escape_html: bool,
    ) -> Generator[tuple[list[dict[str, str]], list[dict[str, str]]], None, None]:
        r"""Generate output text in stream.

        Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
        Output: infer.chatbot, infer.messages
        """
        # placeholder assistant turn; overwritten with partial text on every yield
        chatbot.append({"role": "assistant", "content": ""})
        response = ""
        for new_text in self.stream_chat(
            messages,
            system,
            tools,
            images=[image] if image else None,
            videos=[video] if video else None,
            audios=[audio] if audio else None,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            temperature=temperature,
            skip_special_tokens=skip_special_tokens,
        ):
            response += new_text
            if tools:
                # may return a list of tool calls or the plain string
                result = self.engine.template.extract_tool(response)
            else:
                result = response

            if isinstance(result, list):  # tool call(s): render as a json code block
                tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
                tool_calls = json.dumps(tool_calls, ensure_ascii=False)
                output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                bot_text = "```json\n" + tool_calls + "\n```"
            else:  # plain text: fold any chain-of-thought into a collapsible section
                output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

            chatbot[-1] = {"role": "assistant", "content": bot_text}
            yield chatbot, output_messages