# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.constants import EngineName
from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
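    # Run the given loop forever on the current thread. ChatModel starts this
    # in a daemon thread so its sync methods can submit coroutines to the loop.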
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    r"""
    General class for chat models. Backed by huggingface or vllm engines.

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        if model_args.infer_backend == EngineName.HF:
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == EngineName.VLLM:
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

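        # Run a private event loop on a daemon thread; the sync wrappers below
        # schedule coroutines onto it via asyncio.run_coroutine_threadsafe().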
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses of the chat model.
        """
        task = asyncio.run_coroutine_threadsafe(
            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
        )
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Asynchronously gets a list of responses of the chat model.
        """
        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""
        Gets the response token-by-token of the chat model.
        """
        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
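        # Drive the async generator from sync code: schedule each __anext__()
        # call on the background loop and block until the next token arrives.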
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Asynchronously gets the response token-by-token of the chat model.
        """
        async for new_token in self.engine.stream_chat(
            messages, system, tools, images, videos, audios, **input_kwargs
        ):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores of the reward model.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Asynchronously gets a list of scores of the reward model.
        """
        return await self.engine.get_scores(batch_input, **input_kwargs)


def run_chat() -> None:
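    # Simple interactive REPL built on ChatModel's sync streaming API.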
    if os.name != "nt":
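        # `readline` provides line editing and input history on POSIX systems;
        # it is unavailable on Windows (os.name == "nt"), hence the check.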
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})