# Copyright 2025 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from collections.abc import AsyncGenerator, Generator
from threading import Thread
from typing import TYPE_CHECKING, Any, Optional

from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .sglang_engine import SGLangEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from .base_engine import BaseEngine, Response


chenych's avatar
chenych committed
37
def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
38
39
40
41
42
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
chenych's avatar
chenych committed
43
    r"""General class for chat models. Backed by huggingface or vllm engines.
luopl's avatar
luopl committed
44
45
46
47
48
49

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().
    """

chenych's avatar
chenych committed
50
    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
51
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
chenych's avatar
chenych committed
52
        if model_args.infer_backend == EngineName.HF:
chenych's avatar
chenych committed
53
            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
chenych's avatar
chenych committed
54
        elif model_args.infer_backend == EngineName.VLLM:
chenych's avatar
chenych committed
55
56
57
            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == EngineName.SGLANG:
            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
58
        else:
luopl's avatar
luopl committed
59
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
60
61
62
63
64
65
66

        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
chenych's avatar
chenych committed
67
        messages: list[dict[str, str]],
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
68
69
        system: Optional[str] = None,
        tools: Optional[str] = None,
chenych's avatar
chenych committed
70
71
72
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
73
        **input_kwargs,
chenych's avatar
chenych committed
74
75
    ) -> list["Response"]:
        r"""Get a list of responses of the chat model."""
luopl's avatar
luopl committed
76
        task = asyncio.run_coroutine_threadsafe(
chenych's avatar
chenych committed
77
            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
luopl's avatar
luopl committed
78
        )
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
79
80
81
82
        return task.result()

    async def achat(
        self,
chenych's avatar
chenych committed
83
        messages: list[dict[str, str]],
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
84
85
        system: Optional[str] = None,
        tools: Optional[str] = None,
chenych's avatar
chenych committed
86
87
88
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
89
        **input_kwargs,
chenych's avatar
chenych committed
90
91
    ) -> list["Response"]:
        r"""Asynchronously get a list of responses of the chat model."""
chenych's avatar
chenych committed
92
        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
93
94
95

    def stream_chat(
        self,
chenych's avatar
chenych committed
96
        messages: list[dict[str, str]],
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
97
98
        system: Optional[str] = None,
        tools: Optional[str] = None,
chenych's avatar
chenych committed
99
100
101
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
102
103
        **input_kwargs,
    ) -> Generator[str, None, None]:
chenych's avatar
chenych committed
104
        r"""Get the response token-by-token of the chat model."""
chenych's avatar
chenych committed
105
        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
106
107
108
109
110
111
112
113
114
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
chenych's avatar
chenych committed
115
        messages: list[dict[str, str]],
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
116
117
        system: Optional[str] = None,
        tools: Optional[str] = None,
chenych's avatar
chenych committed
118
119
120
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
121
122
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
chenych's avatar
chenych committed
123
        r"""Asynchronously get the response token-by-token of the chat model."""
chenych's avatar
chenych committed
124
125
126
        async for new_token in self.engine.stream_chat(
            messages, system, tools, images, videos, audios, **input_kwargs
        ):
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
127
128
129
130
            yield new_token

    def get_scores(
        self,
chenych's avatar
chenych committed
131
        batch_input: list[str],
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
132
        **input_kwargs,
chenych's avatar
chenych committed
133
134
    ) -> list[float]:
        r"""Get a list of scores of the reward model."""
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
135
136
137
138
139
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
chenych's avatar
chenych committed
140
        batch_input: list[str],
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
141
        **input_kwargs,
chenych's avatar
chenych committed
142
143
    ) -> list[float]:
        r"""Asynchronously get a list of scores of the reward model."""
Rayyyyy's avatar
V0.6.3  
Rayyyyy committed
144
        return await self.engine.get_scores(batch_input, **input_kwargs)
chenych's avatar
chenych committed
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184


def run_chat() -> None:
    r"""Run an interactive command-line chat session.

    Reads user queries in a loop; the commands `clear` (reset history and free
    cached GPU memory) and `exit` are handled specially. Blocks until the user
    exits.
    """
    if os.name != "nt":
        # readline enables line editing/history on POSIX terminals; optional.
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue

        command = query.strip()  # strip once; compared against both commands
        if command == "exit":
            break

        if command == "clear":
            messages = []
            torch_gc()  # release cached accelerator memory held by old history
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})