chat_model.py 7.84 KB
Newer Older
chenych's avatar
chenych committed
1
# Copyright 2025 THUDM and the LlamaFactory team.
chenych's avatar
chenych committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
18
import asyncio
chenych's avatar
chenych committed
19
import os
chenych's avatar
chenych committed
20
from collections.abc import AsyncGenerator, Generator
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
21
from threading import Thread
chenych's avatar
chenych committed
22
from typing import TYPE_CHECKING, Any, Optional
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
23

chenych's avatar
chenych committed
24
from ..extras.constants import EngineName
chenych's avatar
chenych committed
25
from ..extras.misc import torch_gc
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
26
27
28
29
from ..hparams import get_infer_args


if TYPE_CHECKING:
chenych's avatar
chenych committed
30
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
31
32
33
    from .base_engine import BaseEngine, Response


chenych's avatar
chenych committed
34
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    r"""Bind *loop* to the current thread and run it until stopped.

    Used as the target of the daemon thread spawned by ``ChatModel``: the
    synchronous wrapper methods submit coroutines to this loop via
    ``asyncio.run_coroutine_threadsafe``.
    """
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    r"""General class for chat models. Backed by huggingface or vllm engines.

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().
    """

    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
        r"""Parse inference arguments, build the engine and start a background loop.

        Args:
            args: Optional dict of inference arguments, forwarded to ``get_infer_args``.

        Raises:
            ImportError: If the package for the selected backend is not installed.
            NotImplementedError: If ``infer_backend`` names an unknown engine.
        """
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        self.engine: "BaseEngine" = self._init_engine(model_args, data_args, finetuning_args, generating_args)

        # The sync wrappers (chat, stream_chat, get_scores) submit coroutines to this
        # dedicated event loop running in a daemon thread, so they work even if the
        # caller's thread already runs its own event loop.
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    @staticmethod
    def _init_engine(model_args, data_args, finetuning_args, generating_args) -> "BaseEngine":
        r"""Instantiate the inference engine selected by ``model_args.infer_backend``.

        Backend packages are imported lazily so that only the selected backend
        needs to be installed; a missing package raises an ImportError with an
        installation hint.
        """
        backend = model_args.infer_backend
        if backend == EngineName.HF:
            from .hf_engine import HuggingfaceEngine

            return HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif backend == EngineName.VLLM:
            try:
                from .vllm_engine import VllmEngine

                return VllmEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "vLLM is not installed, you may need to run `pip install vllm`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        elif backend == EngineName.SGLANG:
            try:
                from .sglang_engine import SGLangEngine

                return SGLangEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "SGLang is not installed, you may need to run `pip install sglang[all]`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        elif backend == EngineName.KT:
            try:
                from .kt_engine import KTransformersEngine

                return KTransformersEngine(model_args, data_args, finetuning_args, generating_args)
            except ImportError as e:
                raise ImportError(
                    "KTransformers is not installed, you may need to run `pip install ktransformers`\n"
                    "or try to use HuggingFace backend: --infer_backend huggingface"
                ) from e
        else:
            raise NotImplementedError(f"Unknown backend: {backend}")

    def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        r"""Get a list of responses of the chat model.

        Synchronous wrapper around :meth:`achat`; blocks until the background
        event loop finishes the coroutine.
        """
        task = asyncio.run_coroutine_threadsafe(
            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
        )
        return task.result()

    async def achat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        r"""Asynchronously get a list of responses of the chat model."""
        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)

    def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""Get the response token-by-token of the chat model.

        Synchronous wrapper around :meth:`astream_chat`: each token is pulled
        from the async generator by scheduling ``__anext__`` on the background
        event loop and blocking on the result.
        """
        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""Asynchronously get the response token-by-token of the chat model."""
        async for new_token in self.engine.stream_chat(
            messages, system, tools, images, videos, audios, **input_kwargs
        ):
            yield new_token

    def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        r"""Get a list of scores of the reward model.

        Synchronous wrapper around :meth:`aget_scores`.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        r"""Asynchronously get a list of scores of the reward model."""
        return await self.engine.get_scores(batch_input, **input_kwargs)
chenych's avatar
chenych committed
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210


def run_chat() -> None:
    r"""Launch an interactive CLI chat session.

    Reads user queries from stdin and streams the model response to stdout.
    Typing ``clear`` resets the conversation history (and frees cached GPU
    memory via ``torch_gc``); typing ``exit`` ends the session.
    """
    if os.name != "nt":  # readline enables line editing/history on POSIX terminals
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        # NOTE: the original `except Exception: raise` clause was a no-op
        # (re-raising exactly what would propagate anyway) and was removed.

        command = query.strip()
        if command == "exit":
            break

        if command == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})