"vscode:/vscode.git/clone" did not exist on "6c5ccb11f993ccc88c4761b8c31e0fefcbc1900f"
client.py 3.32 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""

This is a client part of composite_demo.
We provide two clients, HFClient and VLLMClient, which are used to interact with the model.
The HFClient is used to interact with the  transformers backend, and the VLLMClient is used to interact with the VLLM model.

"""

import json
from collections.abc import Generator
from copy import deepcopy
from enum import Enum, auto
from typing import Protocol

import streamlit as st
from conversation import Conversation, build_system_prompt
from tools.tool_registry import ALL_TOOLS


class ClientType(Enum):
    HF = auto()
    VLLM = auto()
    API = auto()


class Client(Protocol):
    def __init__(self, model_path: str): ...

    def generate_stream(
        self,
        tools: list[dict],
        history: list[Conversation],
        **parameters,
    ) -> Generator[tuple[str | dict, list[dict]]]: ...
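
# Note (descriptive, inferred from the signature above): each yielded tuple pairs
# a piece of model output (plain text, or a dict describing a tool call) with the
# updated chat history; see process_response below for how tool-call dicts are built.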


def process_input(history: list[Conversation], tools: list[dict], role_name_replace: dict | None = None) -> list[dict]:
    chat_history = []
    # The system prompt (including tool definitions) is always prepended, even when no tools are selected.
    chat_history.append({"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)})

    for conversation in history:
        role = str(conversation.role).removeprefix("<|").removesuffix("|>")
        if role_name_replace:
            role = role_name_replace.get(role, role)
        item = {
            "role": role,
            "content": conversation.content,
        }
        if conversation.metadata:
            item["metadata"] = conversation.metadata
        # Only append image for user
        if role == "user" and conversation.image:
            item["image"] = conversation.image
        chat_history.append(item)

    return chat_history
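
# Illustrative shape of the list returned by process_input (placeholder values;
# the actual system prompt text is produced by build_system_prompt):
#
#     [
#         {"role": "system", "content": "<system prompt built from the selected tools>"},
#         {"role": "user", "content": "...", "image": <PIL image, user turns only>},
#         {"role": "assistant", "metadata": "<tool name, when present>", "content": "..."},
#     ]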


def process_response(output, history):
    content = ""
    history = deepcopy(history)
    for response in output.split("<|assistant|>"):
        if "\n" in response:
            metadata, content = response.split("\n", maxsplit=1)
        else:
            metadata, content = "", response
        if not metadata.strip():
            content = content.strip()
            history.append({"role": "assistant", "metadata": metadata, "content": content})
            # Substitute the "[[训练时间]]" (training time) placeholder emitted by the model with "2023年" (the year 2023).
            content = content.replace("[[训练时间]]", "2023年")
        else:
            history.append({"role": "assistant", "metadata": metadata, "content": content})
            if history[0]["role"] == "system" and "tools" in history[0]:
                parameters = json.loads(content)
                content = {"name": metadata.strip(), "parameters": parameters}
            else:
                content = {"name": metadata.strip(), "content": content}
    return content, history
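
# How process_response interprets a completion (descriptive note):
# - A segment with no tool name on its first line is returned as plain text
#   (stripped), with the "[[训练时间]]" placeholder substituted.
# - A segment whose first line is a tool name is returned as a dict: when the
#   first history entry is a system message carrying a "tools" key, the body is
#   parsed as JSON into {"name": ..., "parameters": {...}}; otherwise it is kept
#   verbatim as {"name": ..., "content": ...}.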


# glm-4v-9b is not available in the vLLM backend; use HFClient instead.
@st.cache_resource(max_entries=1, show_spinner="Loading model...")
def get_client(model_path, typ: ClientType) -> Client:
    match typ:
        case ClientType.HF:
            from clients.hf import HFClient

            return HFClient(model_path)
        case ClientType.VLLM:
            try:
                from clients.vllm import VLLMClient
            except ImportError as e:
                e.msg += "; did you forget to install vLLM?"
                raise
            return VLLMClient(model_path)
        case ClientType.API:
            from clients.openai import APIClient

            return APIClient(model_path)

    raise NotImplementedError(f"Client type {typ} is not supported.")
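
# Minimal usage sketch (illustrative only; the model path and sampling
# parameters below are placeholders, not values from this repo):
#
#     client = get_client("/path/to/model", ClientType.HF)
#     for chunk, _history in client.generate_stream(tools, history, temperature=0.8):
#         ...  # render the streamed text or tool call in the UI
#
# Here `history` is a list[Conversation] (see conversation.py) and `tools` is a
# list of tool definition dicts, matching the `tools: list[dict]` parameter above.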