# vertexai.py: SGLang language backend for Google Vertex AI models.
import os
import warnings
from typing import Optional

from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

# The Vertex AI SDK is an optional dependency. When it cannot be imported,
# the ImportError object itself is stored under the name ``GenerativeModel``
# so that importing this module still succeeds; the VertexAI backend checks
# for this and re-raises the stored error when it is constructed.
try:
    import vertexai
    from vertexai.preview.generative_models import (
        GenerationConfig,
        GenerativeModel,
        Image,
    )
except ImportError as missing_sdk_error:
    GenerativeModel = missing_sdk_error


class VertexAI(BaseBackend):
    """SGLang backend that sends generation requests to Google Vertex AI.

    The GCP project is read from the ``GCP_PROJECT_ID`` environment variable
    (required) and the region from ``GCP_LOCATION`` (optional). Every request
    is served by ``GenerativeModel(model_name).generate_content``.
    """

    def __init__(self, model_name, safety_settings=None):
        """Initialize the Vertex AI SDK and store per-backend settings.

        Args:
            model_name: Name of the Vertex AI generative model to call.
            safety_settings: Forwarded unchanged as ``safety_settings`` to every
                ``generate_content`` call.

        Raises:
            ImportError: if the ``vertexai`` SDK could not be imported.
            KeyError: if the ``GCP_PROJECT_ID`` environment variable is unset.
        """
        super().__init__()

        # The module-level import stores the ImportError in place of
        # GenerativeModel instead of raising, so a missing SDK only surfaces
        # when this backend is actually instantiated.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        """Return the chat template used to format prompts for this backend."""
        return self.chat_template

    def _build_prompt(self, s: StreamExecutor):
        """Build the ``generate_content`` input for the current program state.

        Chat programs (non-empty ``s.messages_``) are converted to Vertex AI
        content dicts. Otherwise the raw text ``s.text_`` is sent; it is split
        around image tokens into text/Image parts when ``s.cur_images`` is
        non-empty.
        """
        if s.messages_:
            return self.messages_to_vertexai_input(s.messages_)
        # single-turn
        if s.cur_images:
            return self.text_to_vertexai_input(s.text_, s.cur_images)
        return s.text_

    def _generate_content(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        stream: bool = False,
    ):
        """Issue one ``generate_content`` request shared by both entry points."""
        return GenerativeModel(self.model_name).generate_content(
            self._build_prompt(s),
            stream=stream,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a blocking completion and return ``(text, {})``."""
        response = self._generate_content(s, sampling_params)
        return response.text, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding ``(text_chunk, {})`` per response chunk."""
        for chunk in self._generate_content(s, sampling_params, stream=True):
            yield chunk.text, {}

    def text_to_vertexai_input(self, text, images):
        """Interleave text segments and images into a Vertex AI content list.

        ``text`` is split on the chat template's image token, and the i-th image
        is placed where the i-th token was. Empty text segments are omitted.

        Args:
            text: Prompt text containing one image token per image.
            images: Sequence of ``(image_path, image_data)`` pairs. Only
                ``image_data`` is used, via ``Image.from_bytes``.

        Returns:
            A list mixing ``str`` segments and ``Image`` objects.

        Raises:
            ValueError: if ``text`` contains fewer image tokens than there are
                images. The previous code failed here with a bare IndexError.
        """
        images = list(images)
        image_token = self.chat_template.image_token
        segments = text.split(image_token)
        if len(segments) <= len(images):
            raise ValueError(
                f"Prompt contains {len(segments) - 1} image token(s) but "
                f"{len(images)} image(s) were provided."
            )

        parts = []
        for segment, (_image_path, image_data) in zip(segments, images):
            if segment:
                parts.append(segment)
            parts.append(Image.from_bytes(image_data))

        # Re-join everything after the last consumed token. Surplus image
        # tokens (with no matching image) are kept as literal text rather than
        # silently dropping the text that follows them.
        tail = image_token.join(segments[len(images):])
        if tail:
            parts.append(tail)
        return parts

    @staticmethod
    def _image_url_to_inline_data(image):
        """Convert an OpenAI ``image_url`` content item to a Vertex AI inline-data part.

        The URL is expected to be a base64 data URL
        (``data:<mime>;base64,<payload>``). The MIME type is taken from the URL
        header when one is declared; otherwise ``image/jpeg`` is used, which was
        the previous hard-coded value.

        Raises:
            ValueError: if the item is not of type ``image_url`` or the URL has
                no ``,`` separating header and payload.
        """
        if image.get("type") != "image_url":
            raise ValueError(
                f"Expected an 'image_url' content item, got {image.get('type')!r}"
            )
        url = image["image_url"]["url"]
        header, sep, payload = url.partition(",")
        if not sep:
            raise ValueError("image_url must be a data URL of the form 'data:<mime>;base64,<data>'")

        mime_type = "image/jpeg"
        if header.startswith("data:"):
            declared = header[len("data:"):].split(";", 1)[0]
            if declared:
                mime_type = declared
        return {"inline_data": {"data": payload, "mime_type": mime_type}}

    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style chat messages into Vertex AI ``contents`` dicts.

        Each message has a ``role`` and a ``content``. The content is either a
        string, or a list whose first item carries ``text`` and whose remaining
        items are ``image_url`` entries.

        Vertex AI has no system role here, so a system message triggers a warning
        and is emulated by two turns: a user turn prefixed with
        ``"System prompt: "`` and a canned ``"Understood."`` model turn.
        ``assistant`` maps to the ``model`` role.

        Raises:
            ValueError: for any role other than system/user/assistant, or for a
                malformed image item. Previously an unknown role either raised
                UnboundLocalError or silently re-appended the prior message.
        """
        role_map = {"user": "user", "assistant": "model"}
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            content = msg["content"]
            text = content if isinstance(content, str) else content[0]["text"]
            role = msg["role"]

            if role == "system":
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue

            if role not in role_map:
                raise ValueError(f"Unsupported message role for VertexAI: {role!r}")

            parts = [{"text": text}]
            # images: every list item after the leading text item
            if isinstance(content, list):
                parts.extend(
                    self._image_url_to_inline_data(image) for image in content[1:]
                )
            vertexai_message.append({"role": role_map[role], "parts": parts})
        return vertexai_message