client.py 3.1 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""

This is the client module of composite_demo.
It provides two clients, HFClient and VLLMClient, for interacting with the model.
HFClient talks to the model through the Hugging Face transformers backend, and VLLMClient through the vLLM backend.

"""

import json
from collections.abc import Generator
from copy import deepcopy
from enum import Enum, auto
from typing import Protocol

import streamlit as st

from conversation import Conversation, build_system_prompt
from tools.tool_registry import ALL_TOOLS


class ClientType(Enum):
    """Inference backends that ``get_client`` knows how to construct."""

    # Hugging Face transformers backend (HFClient).
    HF = auto()
    # vLLM backend (VLLMClient).
    VLLM = auto()


class Client(Protocol):
    """Structural interface every model-backend client must satisfy."""

    def __init__(self, model_path: str):
        """Create a client for the model located at ``model_path``."""
        ...

    def generate_stream(
        self,
        tools: list[dict],
        history: list[Conversation],
        **parameters,
    ) -> Generator[tuple[str | dict, list[dict]], None, None]:
        """Stream model output for the given conversation.

        Args:
            tools: Tool definitions made available to the model.
            history: Conversation turns to condition generation on.
            **parameters: Backend-specific generation options (presumably
                sampling settings such as temperature — confirm per backend).

        Yields:
            ``(content, history)`` tuples, where ``content`` is either text or a
            dict (presumably a structured tool call — confirm against the
            implementations) and ``history`` is a list of message dicts.
        """
        ...


def process_input(history: "list[Conversation]", tools: list[dict]) -> list[dict]:
    """Convert conversation turns into a list of chat message dicts.

    Args:
        history: Conversation turns. Each item must expose ``role``,
            ``content``, ``metadata`` and ``image`` attributes.
        tools: Tool definitions enabled for this request. When non-empty, a
            system message produced by ``build_system_prompt`` (from all
            registered tools plus ``tools``) is prepended.

    Returns:
        Message dicts with ``role`` and ``content`` keys. ``metadata`` is added
        only when the turn's metadata is truthy, and ``image`` only for user
        turns that carry an image.
    """
    chat_history: list[dict] = []
    if tools:
        chat_history.append(
            {"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
        )

    for conversation in history:
        # str(role) presumably renders as a template token like "<|user|>";
        # strip the delimiters so only the bare role name remains.
        role = str(conversation.role).removeprefix("<|").removesuffix("|>")
        item = {
            "role": role,
            "content": conversation.content,
        }
        if conversation.metadata:
            item["metadata"] = conversation.metadata
        # Images are forwarded only on user turns.
        if role == "user" and conversation.image:
            item["image"] = conversation.image
        chat_history.append(item)

    return chat_history


def process_response(output, history):
    """Parse raw model output into a final reply and an extended history.

    ``output`` is split on the ``<|assistant|>`` marker, and every segment is
    appended as an assistant turn to a deep copy of ``history``. The caller's
    list is never mutated. When a segment contains a newline, its first line
    is taken as the turn's metadata and the rest as the body. Otherwise the
    metadata is empty and the whole segment is the body.

    Returns:
        ``(content, history)``, where ``content`` reflects the last segment:
          * blank metadata: the stripped body with the training-time
            placeholder ``[[训练时间]]`` substituted (the history entry keeps
            the placeholder);
          * metadata, with the first history entry being a system turn that has
            a ``tools`` key: ``{"name": ..., "parameters": ...}`` with the body
            parsed as JSON;
          * metadata otherwise: ``{"name": ..., "content": ...}``.

    Raises:
        json.JSONDecodeError: if a tool-call body is not valid JSON.
    """
    updated = deepcopy(history)
    reply = ""
    for segment in output.split("<|assistant|>"):
        head, newline, tail = segment.partition("\n")
        meta, body = (head, tail) if newline else ("", segment)
        tool_name = meta.strip()

        if tool_name:
            updated.append({"role": "assistant", "metadata": meta, "content": body})
            first_turn = updated[0]
            if first_turn["role"] == "system" and "tools" in first_turn:
                reply = {"name": tool_name, "parameters": json.loads(body)}
            else:
                reply = {"name": tool_name, "content": body}
            continue

        body = body.strip()
        updated.append({"role": "assistant", "metadata": meta, "content": body})
        reply = body.replace("[[训练时间]]", "2023年")
    return reply, updated


# glm-4v-9b is not available in vLLM backend, use HFClient instead.
@st.cache_resource(max_entries=1, show_spinner="Loading model...")
def get_client(model_path, typ: ClientType) -> Client:
    match typ:
        case ClientType.HF:
            from clients.hf import HFClient

            return HFClient(model_path)
        case ClientType.VLLM:
            try:
                from clients.vllm import VLLMClient
            except ImportError as e:
                e.msg += "; did you forget to install vLLM?"
                raise
            return VLLMClient(model_path)

    raise NotImplementedError(f"Client type {typ} is not supported.")