# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from ..data.mm_plugin import ImageInput, VideoInput
    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
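    r"""Run the given event loop forever; used as the target of the background daemon thread."""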
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    r"""
    General class for chat models. Backed by huggingface or vllm engines.

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().
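
    Example (a minimal usage sketch; the model path below is a placeholder, and
    `response_text` follows this repo's Response type):
        chat_model = ChatModel({"model_name_or_path": "path/to/model"})
        responses = chat_model.chat([{"role": "user", "content": "Hello"}])
        print(responses[0].response_text)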
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        self.engine_type = model_args.infer_backend
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

        # Run a dedicated event loop in a daemon thread so the sync methods can
        # submit coroutines to it via asyncio.run_coroutine_threadsafe().
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses of the chat model.
        """
        task = asyncio.run_coroutine_threadsafe(
            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
        )
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Asynchronously gets a list of responses of the chat model.
        """
        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""
        Gets the response token-by-token of the chat model.
        """
        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
        while True:
            try:
                # Pull the next token from the async generator on the background
                # loop, blocking this thread until it arrives.
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Asynchronously gets the response token-by-token of the chat model.
        """
        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores of the reward model.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Asynchronously gets a list of scores of the reward model.
        """
        return await self.engine.get_scores(batch_input, **input_kwargs)
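

# A minimal async usage sketch (illustrative only; `_demo_astream_chat` and the
# model path are hypothetical, not part of the original module). Inside a
# running event loop, call the async methods directly instead of the sync wrappers:
async def _demo_astream_chat() -> None:
    chat_model = ChatModel({"model_name_or_path": "path/to/model"})
    messages = [{"role": "user", "content": "Hello"}]
    async for token in chat_model.astream_chat(messages):
        print(token, end="", flush=True)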


def run_chat() -> None:
    if os.name != "nt":
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})