runtime_endpoint.py 10.3 KB
Newer Older
Lianmin Zheng's avatar
Lianmin Zheng committed
1
import json
2
import warnings
Mingyi's avatar
Mingyi committed
3
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
4
5

from sglang.global_config import global_config
Ying Sheng's avatar
Ying Sheng committed
6
from sglang.lang.backend.base_backend import BaseBackend
7
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
8
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
Lianmin Zheng's avatar
Lianmin Zheng committed
9
from sglang.lang.interpreter import StreamExecutor
10
11
12
13
14
15
16
from sglang.lang.ir import (
    REGEX_BOOL,
    REGEX_FLOAT,
    REGEX_INT,
    REGEX_STR,
    SglSamplingParams,
)
17
from sglang.utils import http_request
Lianmin Zheng's avatar
Lianmin Zheng committed
18
19
20


class RuntimeEndpoint(BaseBackend):
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22
23
24
25
    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        """Connect to a running server and pick the chat template to use.

        Args:
            base_url: Root URL of the server; endpoint paths are appended to it.
            api_key: Optional key forwarded as ``api_key`` to every
                ``http_request`` call.
            verify: Optional value forwarded as ``verify`` to every
                ``http_request`` call.
            chat_template_name: Name of the chat template to load. When it is
                empty or ``None``, the template is looked up from the
                ``model_path`` reported by the server.

        Raises:
            RuntimeError: If ``/get_model_info`` does not answer with status 200.
        """
        super().__init__()
        self.support_concate_and_append = True

        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        # Query the server once and keep its model description for later use.
        model_info_url = self.base_url + "/get_model_info"
        response = http_request(
            model_info_url, api_key=self.api_key, verify=self.verify
        )
        self._assert_success(response)
        self.model_info = response.json()

        # An explicit template name wins over the model-path based lookup.
        self.chat_template = (
            get_chat_template(chat_template_name)
            if chat_template_name
            else get_chat_template_by_model_path(self.model_info["model_path"])
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
51
52

    def get_model_name(self):
        return self.model_info["model_path"]

Liangsheng Yin's avatar
Liangsheng Yin committed
53
54
55
    def flush_cache(self):
        """Send a request to the server's ``/flush_cache`` endpoint.

        Raises:
            RuntimeError: If the response status is not 200.
        """
        url = self.base_url + "/flush_cache"
        response = http_request(url, api_key=self.api_key, verify=self.verify)
        self._assert_success(response)
Liangsheng Yin's avatar
Liangsheng Yin committed
60
61
62
63

    def get_server_args(self):
        """Fetch ``/get_server_args`` and return the decoded JSON body.

        Raises:
            RuntimeError: If the response status is not 200.
        """
        url = self.base_url + "/get_server_args"
        response = http_request(url, api_key=self.api_key, verify=self.verify)
        self._assert_success(response)
        server_args = response.json()
        return server_args

Lianmin Zheng's avatar
Lianmin Zheng committed
70
71
72
73
74
75
76
    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Submit ``prefix_str`` to ``/generate`` with ``max_new_tokens`` set to 0.

        No new tokens are requested, so the call only makes the server
        process the prefix (presumably so it ends up in the server-side
        cache, as the method name suggests).

        Raises:
            RuntimeError: If the response status is not 200.
        """
        payload = {"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}
        response = http_request(
            self.base_url + "/generate",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
Lianmin Zheng's avatar
Lianmin Zheng committed
81
82

    def commit_lazy_operations(self, s: StreamExecutor):
        """Send the executor's current text to ``/generate`` without decoding.

        The request carries ``s.text_`` (plus the executor's image, if it has
        one) and asks for zero new tokens.

        Raises:
            RuntimeError: If the response status is not 200.
        """
        payload = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, payload)

        generate_url = self.base_url + "/generate"
        response = http_request(
            generate_url, json=payload, api_key=self.api_key, verify=self.verify
        )
        self._assert_success(response)
Lianmin Zheng's avatar
Lianmin Zheng committed
92
93
94
95

    def fill_image(self, s: StreamExecutor):
        """Send the executor's text and image to ``/generate`` without decoding.

        The payload holds ``s.text_`` with ``max_new_tokens`` set to 0;
        ``_add_images`` attaches the image data when the executor has an image.

        Raises:
            RuntimeError: If the response status is not 200.
        """
        payload = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, payload)

        generate_url = self.base_url + "/generate"
        response = http_request(
            generate_url, json=payload, api_key=self.api_key, verify=self.verify
        )
        self._assert_success(response)
Lianmin Zheng's avatar
Lianmin Zheng committed
103

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        if sampling_params.dtype is None:
            return

        if sampling_params.stop == ():
            sampling_params.stop = []

        dtype_regex = None
        if sampling_params.dtype in ["int", int]:

            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["float", float]:

            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif sampling_params.dtype in ["str", str]:

            dtype_regex = REGEX_STR
        elif sampling_params.dtype in ["bool", bool]:

            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex

Lianmin Zheng's avatar
Lianmin Zheng committed
136
137
138
    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one blocking ``/generate`` request for the executor's text.

        Args:
            s: Executor whose ``text_`` (and image, if any) forms the prompt.
            sampling_params: Sampling settings. Its ``dtype`` is first turned
                into a regex constraint, which updates the object in place.

        Returns:
            A ``(text, meta_info)`` tuple taken from the JSON response.

        Raises:
            RuntimeError: If the response status is not 200, or if
                ``sampling_params.dtype`` is not supported.
        """
        self._handle_dtype_to_regex(sampling_params)

        payload = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Logprob options live at the top level of the payload and are only
        # sent when the sampling params actually carry a value for them.
        logprob_option_names = (
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        )
        for option_name in logprob_option_names:
            option_value = getattr(sampling_params, option_name, None)
            if option_value is None:
                continue
            payload[option_name] = option_value

        self._add_images(s, payload)

        response = http_request(
            self.base_url + "/generate",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)

        result = response.json()
        return result["text"], result["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a ``/generate`` request, yielding text increments.

        The server is expected to answer with lines of the form
        ``data: <json>`` and a final ``data: [DONE]`` line. Each JSON event
        carries the full text generated so far; only the part not yielded
        before is passed on.

        The streaming response is closed when the stream ends, when the
        consumer stops iterating, or when an error occurs, so the underlying
        connection is not left open.

        Args:
            s: Executor whose ``text_`` (and image, if any) forms the prompt.
            sampling_params: Sampling settings. Its ``dtype`` is first turned
                into a regex constraint, which updates the object in place.

        Yields:
            ``(chunk_text, meta_info)`` tuples, one per received event.

        Raises:
            RuntimeError: If the response status is not 200, or if
                ``sampling_params.dtype`` is not supported.
        """
        self._handle_dtype_to_regex(sampling_params)

        data = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        data["stream"] = True
        self._add_images(s, data)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        try:
            self._assert_success(res)

            # Number of characters of the cumulative text already yielded.
            pos = 0
            for raw_line in res.iter_lines(decode_unicode=False):
                line = raw_line.decode("utf-8")
                # Skip keep-alive blanks and anything that is not a data event.
                if not line.startswith("data:"):
                    continue
                if line == "data: [DONE]":
                    break
                event = json.loads(line[5:].strip("\n"))
                chunk_text = event["text"][pos:]
                meta_info = event["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info
        finally:
            # A streamed response keeps its connection until it is fully read
            # or closed; we may stop early, so release it explicitly. The
            # guard covers response objects without a close() method.
            close = getattr(res, "close", None)
            if callable(close):
                close()
Lianmin Zheng's avatar
Lianmin Zheng committed
224
225
226
227
228
229

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Score every choice as a continuation of ``s.text_`` and pick one.

        Steps, each a separate ``/generate`` request with zero new tokens:

        1. Send the bare prompt to learn its token count.
        2. Send ``prompt + choice`` for every choice with logprobs enabled,
           starting two tokens before the end of the prompt.
        3. Optionally (when ``choices_method`` asks for it) re-score the
           returned token ids without the prompt.

        Each logprob entry is indexed as ``entry[0]`` = logprob, ``entry[1]`` =
        token id and ``entry[-1]`` = token text, as used below.

        Args:
            s: Executor whose ``text_`` is the shared prompt.
            choices: Candidate continuations.
            temperature: Must be at most ``1e-5``.
            choices_method: Callable that turns the collected logprobs into the
                final decision.

        Returns:
            Whatever ``choices_method`` returns for the collected statistics.

        Raises:
            RuntimeError: If any request does not answer with status 200.
        """
        assert temperature <= 1e-5

        # Cache common prefix
        prefix_payload = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        prefix_result = self._generate_http_request(s, prefix_payload)
        prompt_len = prefix_result["meta_info"]["prompt_tokens"]
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        scoring_payload = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        scored = self._generate_http_request(s, scoring_payload)

        normalized_prompt_logprobs = [
            entry["meta_info"]["normalized_prompt_logprob"] for entry in scored
        ]
        input_token_logprobs = [
            entry["meta_info"]["input_token_logprobs"] for entry in scored
        ]
        output_token_logprobs = [
            entry["meta_info"]["output_token_logprobs"] for entry in scored
        ]

        # Remove extra token if no token healing occurred
        for idx, token_logprobs in enumerate(input_token_logprobs):
            leading_entry = token_logprobs[0]
            if not s.text_.endswith(leading_entry[-1]):
                continue
            # The leading token still belongs to the prompt: take its logprob
            # out of the average and drop it from the sequence.
            token_count = len(token_logprobs)
            leading_logprob = leading_entry[0]
            normalized_prompt_logprobs[idx] = (
                normalized_prompt_logprobs[idx] * token_count - leading_logprob
            ) / (token_count - 1)
            input_token_logprobs[idx] = token_logprobs[1:]

        # Compute unconditional logprobs if required
        unconditional_token_logprobs = None
        if choices_method.requires_unconditional_logprobs:
            token_ids = [
                [token[1] for token in sequence] for sequence in input_token_logprobs
            ]
            unconditional_payload = {
                "input_ids": token_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            unconditional = self._generate_http_request(s, unconditional_payload)
            unconditional_token_logprobs = [
                entry["meta_info"]["input_token_logprobs"] for entry in unconditional
            ]

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
292
293
294
295
296

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Post the given request ids to ``/concate_and_append_request``.

        Args:
            src_rids: Sent as ``"src_rids"`` in the JSON body.
            dst_rid: Sent as ``"dst_rid"`` in the JSON body.

        Raises:
            RuntimeError: If the response status is not 200.
        """
        payload = {"src_rids": src_rids, "dst_rid": dst_rid}
        url = self.base_url + "/concate_and_append_request"
        response = http_request(
            url, json=payload, api_key=self.api_key, verify=self.verify
        )
        self._assert_success(response)
Lianmin Zheng's avatar
Lianmin Zheng committed
301

302
303
304
305
306
307
308
309
310
311
312
    def _generate_http_request(self, s: StreamExecutor, data):
        """Post ``data`` to ``/generate`` and return the decoded JSON body.

        ``data`` is updated in place with the executor's image data (if the
        executor has an image) before it is sent.

        Raises:
            RuntimeError: If the response status is not 200.
        """
        self._add_images(s, data)

        generate_url = self.base_url + "/generate"
        response = http_request(
            generate_url, json=data, api_key=self.api_key, verify=self.verify
        )
        self._assert_success(response)

        decoded = response.json()
        return decoded

Lianmin Zheng's avatar
Lianmin Zheng committed
313
314
315
316
    def _add_images(self, s: StreamExecutor, data):
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]
317
318
319

    def _assert_success(self, res):
        if res.status_code != 200:
320
            raise RuntimeError(res.json())