runtime_endpoint.py 8.55 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
import json
Mingyi's avatar
Mingyi committed
2
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
3
4

import numpy as np
Liangsheng Yin's avatar
Liangsheng Yin committed
5

Lianmin Zheng's avatar
Lianmin Zheng committed
6
from sglang.global_config import global_config
Ying Sheng's avatar
Ying Sheng committed
7
from sglang.lang.backend.base_backend import BaseBackend
Lianmin Zheng's avatar
Lianmin Zheng committed
8
9
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
Mingyi's avatar
Mingyi committed
10
from sglang.lang.ir import SglSamplingParams
11
from sglang.utils import http_request
Lianmin Zheng's avatar
Lianmin Zheng committed
12
13
14


class RuntimeEndpoint(BaseBackend):
    """Backend that talks to a running SGLang runtime server over HTTP.

    Every request is sent to ``base_url + <path>`` through ``http_request``
    and carries the endpoint's ``api_key`` and ``verify`` settings.
    """

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
    ):
        """Store connection settings and fetch the served model's info.

        Args:
            base_url: Server root URL; endpoint paths such as ``/generate``
                are appended to it verbatim.
            api_key: Optional API key forwarded with every request.
            verify: Forwarded to ``http_request`` as ``verify`` on every
                request (presumably a TLS CA bundle path -- confirm against
                ``sglang.utils.http_request``).

        Raises:
            RuntimeError: if ``/get_model_info`` answers with a non-200 status.
        """
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        # Must run after base_url/api_key/verify are set: _request reads them.
        res = self._request("/get_model_info")
        self.model_info = res.json()

        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_path"]
        )

    def get_model_name(self):
        """Return the ``model_path`` reported by ``/get_model_info``."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Call the server's ``/flush_cache`` endpoint; raise on non-200."""
        self._request("/flush_cache")

    def get_server_args(self):
        """Return the decoded JSON body of ``/get_server_args``."""
        return self._request("/get_server_args").json()

    def get_chat_template(self):
        """Return the chat template resolved from the model path at init."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Send ``prefix_str`` to ``/generate`` with ``max_new_tokens=0``.

        No tokens are requested; presumably the point is to have the server
        process (and cache) the prefix ahead of later requests.
        """
        self._request(
            "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
        )

    def commit_lazy_operations(self, s: StreamExecutor):
        """Send the executor's current text (and image, if any) as a
        zero-new-token ``/generate`` request."""
        self._prefill(s)

    def fill_image(self, s: StreamExecutor):
        """Send the executor's current text and image as a zero-new-token
        ``/generate`` request."""
        self._prefill(s)

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a blocking generation on the executor's current text.

        Returns:
            A ``(text, meta_info)`` tuple taken from the server's JSON reply.

        Raises:
            RuntimeError: on an unsupported ``sampling_params.dtype`` or a
                non-200 response.
        """
        data = self._build_generate_data(s, sampling_params)
        obj = self._request("/generate", json=data).json()
        return obj["text"], obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run a streaming generation, yielding ``(new_text, meta_info)`` pairs.

        The server replies with ``data:``-prefixed lines, terminated by
        ``data: [DONE]``. Each line's ``text`` field is treated as the full
        text so far, so only the suffix not yet yielded is emitted.

        Raises:
            RuntimeError: on an unsupported ``sampling_params.dtype`` or a
                non-200 response.
        """
        data = self._build_generate_data(s, sampling_params, stream=True)
        res = self._request("/generate", json=data, stream=True)

        try:
            pos = 0
            for line in res.iter_lines(decode_unicode=False):
                line = line.decode("utf-8")
                if not (line and line.startswith("data:")):
                    continue
                if line == "data: [DONE]":
                    break
                payload = json.loads(line[5:].strip("\n"))
                chunk_text = payload["text"][pos:]
                pos += len(chunk_text)
                yield chunk_text, payload["meta_info"]
        finally:
            # Release the streamed connection even if the consumer stops
            # iterating early or we break on [DONE] before the body is drained.
            close = getattr(res, "close", None)
            if close is not None:
                close()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        """Pick the choice whose continuation has the best normalized logprob.

        Two requests are made: one to prefill the shared prompt (which also
        yields its token count), and one batched request scoring
        ``s.text_ + choice`` for every choice.

        Returns:
            ``(decision, normalized_prompt_logprobs, input_token_logprobs,
            output_token_logprobs)``, where the three lists are per-choice
            values from each result's ``meta_info``.

        Raises:
            ValueError: if ``choices`` is empty.
            RuntimeError: on a non-200 response.
        """
        assert temperature <= 1e-5, (
            f"select requires temperature <= 1e-5, got {temperature}"
        )
        if not choices:
            raise ValueError("select requires at least one choice")

        # Cache the common prefix and learn its length in tokens.
        prompt_len = self._prefill(s).json()["meta_info"]["prompt_tokens"]

        # Score every choice. Logprobs start two tokens before the end of the
        # shared prefix -- presumably to cover tokenization merges at the
        # prefix/choice boundary (TODO confirm).
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {"max_new_tokens": 0},
            "return_logprob": True,
            "logprob_start_len": max(prompt_len - 2, 0),
        }
        self._add_images(s, data)
        obj = self._request("/generate", json=data).json()

        normalized_prompt_logprobs = [
            r["meta_info"]["normalized_prompt_logprob"] for r in obj
        ]
        decision = choices[np.argmax(normalized_prompt_logprobs)]
        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

        return (
            decision,
            normalized_prompt_logprobs,
            input_token_logprobs,
            output_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Ask the server to concatenate ``src_rids`` and append to ``dst_rid``
        via ``/concate_and_append_request``."""
        self._request(
            "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
        )

    def _request(self, path: str, **kwargs):
        """Send a request to ``self.base_url + path`` and return the response.

        Extra keyword arguments (``json``, ``stream``) are forwarded to
        ``http_request`` unchanged; ``api_key`` and ``verify`` are always
        attached.

        Raises:
            RuntimeError: if the response status is not 200.
        """
        res = http_request(
            self.base_url + path,
            api_key=self.api_key,
            verify=self.verify,
            **kwargs,
        )
        self._assert_success(res)
        return res

    def _prefill(self, s: StreamExecutor):
        """POST the executor's text (plus image, if any) to ``/generate`` with
        ``max_new_tokens=0`` and return the response."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        return self._request("/generate", json=data)

    def _build_generate_data(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        stream: bool = False,
    ) -> dict:
        """Build the ``/generate`` payload shared by generate/generate_stream.

        Raises:
            RuntimeError: if ``sampling_params.dtype`` is neither ``None`` nor
                ``int``/``"int"``.
        """
        params = {
            "skip_special_tokens": global_config.skip_special_tokens_in_output,
            "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
        }
        dtype = sampling_params.dtype
        if dtype is not None:
            if dtype not in [int, "int"]:
                raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
            params["dtype"] = "int"
        # Merged last so explicit sampling params override the global defaults.
        params.update(sampling_params.to_srt_kwargs())

        data = {"text": s.text_, "sampling_params": params}

        # Optional logprob options are sent top-level, only when set.
        for item in (
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ):
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        if stream:
            data["stream"] = True
        self._add_images(s, data)
        return data

    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's single image (if any) to ``data`` in place.

        Only one image is supported. The second element of the first
        ``s.images_`` entry is sent as ``image_data``. Presumably the entry is
        ``(path, encoded_data)`` -- confirm against StreamExecutor.
        """
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        """Raise RuntimeError with the server's error detail on a non-200 status.

        The JSON body is used when it parses. Otherwise (for example an
        HTML/plaintext error page from a proxy) the status code and raw text
        are reported instead of a misleading JSON decode error.
        """
        if res.status_code == 200:
            return
        try:
            detail = res.json()
        except ValueError:  # json/requests JSONDecodeError subclass ValueError
            detail = f"HTTP {res.status_code}: {getattr(res, 'text', '')}"
        raise RuntimeError(detail)