import os
import warnings

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as e:
    # Defer the import failure; it is re-raised when the backend is constructed.
    GenerativeModel = e


class VertexAI(BaseBackend):
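    """SGLang backend that runs generation through Google Vertex AI generative models (e.g., Gemini)."""
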
    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

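        # Project and region come from the environment: GCP_PROJECT_ID is
        # required, GCP_LOCATION is optional.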
        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
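        # Build a chat-style prompt if there is message history, otherwise a
        # single-turn (optionally multimodal) prompt, then issue one request.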
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        parts = []
        # Interleave text segments and images by splitting on the image token.
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                parts.append(text_seg)
            parts.append(Image.from_bytes(image_base64_data))
        # Trailing text after the last image token.
        text_seg = text_segs.pop(0)
        if text_seg != "":
            parts.append(text_seg)
        return parts

    def messages_to_vertexai_input(self, messages):
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

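            # Vertex AI has no "system" role, so a system prompt is emulated
            # as a user turn followed by a canned model acknowledgement.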
            if msg["role"] == "system":
                warnings.warn("System prompt is not supported in VertexAI; emulating it as a user message.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
            else:
                raise ValueError(f"Unsupported message role: {msg['role']}")

            # Remaining content entries are images supplied as data URLs; the
            # base64 payload is forwarded and assumed to be JPEG.
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message
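

# A minimal usage sketch (assumptions: the sglang frontend API is used as shown
# and the model name "gemini-pro" is illustrative; GCP_PROJECT_ID must be set
# in the environment):
#
#   import sglang as sgl
#
#   @sgl.function
#   def answer(s, question):
#       s += sgl.user(question)
#       s += sgl.assistant(sgl.gen("answer", max_tokens=256))
#
#   sgl.set_default_backend(VertexAI("gemini-pro"))
#   state = answer.run(question="What is the capital of France?")
#   print(state["answer"])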