# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
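
"""Synchronous and asynchronous chat interfaces over the inference engines,
plus a simple interactive CLI chat loop."""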

import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from numpy.typing import NDArray

    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
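    """Set the given event loop as current and run it forever.

    Intended as a daemon-thread target so that coroutines can be scheduled
    onto the loop from synchronous code via `run_coroutine_threadsafe`.
    """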
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
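    """General chat model, backed by a HuggingFace or vLLM inference engine.

    The engine's async methods run on a dedicated event loop in a background
    daemon thread, so the synchronous wrappers (`chat`, `stream_chat`,
    `get_scores`) can be called from ordinary, non-async code.

    A minimal usage sketch (the argument names follow this project's
    inference hparams; adjust to your setup):

        chat_model = ChatModel({"model_name_or_path": "path/to/model"})
        responses = chat_model.chat([{"role": "user", "content": "Hello"}])
    """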
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))

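        # Run a long-lived event loop in a background daemon thread; the
        # synchronous wrappers below schedule coroutines onto it.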
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
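        """Get a list of responses from the chat model (blocking).

        Schedules `achat` on the background loop and waits for its result.
        """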
        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
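        """Asynchronously get a list of responses from the chat model."""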
        return await self.engine.chat(messages, system, tools, image, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
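        """Stream the generated text chunk by chunk (blocking generator).

        Bridges the async generator from `astream_chat` into a synchronous
        one by scheduling each `__anext__` call on the background loop.
        """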
        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
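        """Asynchronously stream the generated text chunk by chunk."""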
        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
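        """Get a list of scores for a batch of inputs (blocking)."""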
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
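        """Asynchronously get a list of scores for a batch of inputs."""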
        return await self.engine.get_scores(batch_input, **input_kwargs)


def run_chat() -> None:
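    """Run an interactive chat session in the terminal."""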
    if os.name != "nt":
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:  # re-raise anything unexpected unchanged
            raise

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})
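

if __name__ == "__main__":
    # Illustrative entry point: in LLaMA Factory, `run_chat` is normally
    # invoked through the project CLI rather than by running this module.
    run_chat()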