runtime_endpoint.py 17.1 KB
Newer Older
1
import atexit
Lianmin Zheng's avatar
Lianmin Zheng committed
2
import json
3
import multiprocessing
4
import warnings
5
6
7
8
from typing import Dict, List, Optional, Union

import aiohttp
import requests
Lianmin Zheng's avatar
Lianmin Zheng committed
9
10

from sglang.global_config import global_config
Ying Sheng's avatar
Ying Sheng committed
11
from sglang.lang.backend.base_backend import BaseBackend
12
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
13
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
Lianmin Zheng's avatar
Lianmin Zheng committed
14
from sglang.lang.interpreter import StreamExecutor
15
16
17
18
19
20
21
from sglang.lang.ir import (
    REGEX_BOOL,
    REGEX_FLOAT,
    REGEX_INT,
    REGEX_STR,
    SglSamplingParams,
)
22
from sglang.utils import http_request
Lianmin Zheng's avatar
Lianmin Zheng committed
23
24
25


class RuntimeEndpoint(BaseBackend):
    """Frontend-language backend that talks to a running SRT HTTP server."""

    def __init__(
        self,
        base_url: str,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
        chat_template_name: Optional[str] = None,
    ):
        """Connect to the server at *base_url* and fetch its model info.

        Args:
            base_url: Root URL of the running server.
            api_key: Optional API key forwarded on every request.
            verify: TLS verification setting forwarded to the HTTP layer.
            chat_template_name: Explicit chat-template name; when omitted the
                template is inferred from the server's model path.
        """
        super().__init__()
        self.support_concate_and_append = True
        self.base_url = base_url
        self.api_key = api_key
        self.verify = verify

        response = http_request(
            f"{self.base_url}/get_model_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
        self.model_info = response.json()

        # Prefer an explicitly named template; otherwise infer one from the
        # model path reported by the server.
        if chat_template_name:
            self.chat_template = get_chat_template(chat_template_name)
        else:
            model_path = self.model_info["model_path"]
            self.chat_template = get_chat_template_by_model_path(model_path)
    def get_model_name(self):
        """Return the model path reported by the server at construction time."""
        info = self.model_info
        return info["model_path"]
    def flush_cache(self):
        """Ask the server to drop its prefix cache."""
        response = http_request(
            f"{self.base_url}/flush_cache",
            method="POST",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
    def get_server_info(self):
        """Fetch and return the server's info endpoint as decoded JSON."""
        response = http_request(
            f"{self.base_url}/get_server_info",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
        return response.json()
    def get_chat_template(self):
        """Accessor for the chat template resolved during construction."""
        return self.chat_template
    def cache_prefix(self, prefix_str: str):
        """Warm the server's prefix cache with *prefix_str* (no tokens generated)."""
        payload = {"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}
        response = http_request(
            f"{self.base_url}/generate",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
    def start_profile(self):
        """Start the server-side profiler."""
        response = http_request(
            f"{self.base_url}/start_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)

    def stop_profile(self):
        """Stop the server-side profiler."""
        response = http_request(
            f"{self.base_url}/stop_profile",
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
    def commit_lazy_operations(self, s: StreamExecutor):
        """Send the executor's accumulated text to the server without generating."""
        payload = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, payload)
        response = http_request(
            f"{self.base_url}/generate",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
    def fill_image(self, s: StreamExecutor):
        """Upload the executor's image data alongside its current text."""
        payload = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, payload)
        response = http_request(
            f"{self.base_url}/generate",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
    def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
        """Translate ``sampling_params.dtype`` into a regex constraint.

        int/float dtypes additionally add whitespace stop strings so the
        model terminates the number. Raises ``RuntimeError`` on an unknown
        dtype. No-op when dtype is unset.
        """
        dtype = sampling_params.dtype
        if dtype is None:
            return

        # The default stop value may be an immutable tuple; switch to a list
        # so the branches below can extend it.
        if sampling_params.stop == ():
            sampling_params.stop = []

        if dtype in ("int", int):
            dtype_regex = REGEX_INT
            sampling_params.stop.extend([" ", "\n"])
        elif dtype in ("float", float):
            dtype_regex = REGEX_FLOAT
            sampling_params.stop.extend([" ", "\n"])
        elif dtype in ("str", str):
            dtype_regex = REGEX_STR
        elif dtype in ("bool", bool):
            dtype_regex = REGEX_BOOL
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        # dtype wins over an explicitly supplied regex; warn on the conflict.
        if dtype_regex is not None and sampling_params.regex is not None:
            warnings.warn(
                f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
            )

        sampling_params.regex = dtype_regex
    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation call.

        Returns:
            A ``(completion_text, meta_info)`` tuple from the server response.
        """
        self._handle_dtype_to_regex(sampling_params)

        payload = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Forward the optional logprob-related fields only when they are set.
        for key in (
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ):
            value = getattr(sampling_params, key, None)
            if value is not None:
                payload[key] = value

        self._add_images(s, payload)

        response = http_request(
            f"{self.base_url}/generate",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)

        obj = response.json()
        return obj["text"], obj["meta_info"]
    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one streaming generation call.

        Yields:
            ``(new_text, meta_info)`` tuples as chunks arrive from the server.
        """
        self._handle_dtype_to_regex(sampling_params)

        payload = {
            "text": s.text_,
            "sampling_params": {
                "skip_special_tokens": global_config.skip_special_tokens_in_output,
                "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                **sampling_params.to_srt_kwargs(),
            },
        }

        # Forward the optional logprob-related fields only when they are set.
        for key in (
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ):
            value = getattr(sampling_params, key, None)
            if value is not None:
                payload[key] = value

        payload["stream"] = True
        self._add_images(s, payload)

        response = http_request(
            f"{self.base_url}/generate",
            json=payload,
            stream=True,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)

        # Parse the server-sent-events stream. The server sends cumulative
        # text; `pos` tracks how much has already been yielded.
        pos = 0
        for raw_line in response.iter_lines(decode_unicode=False):
            line = raw_line.decode("utf-8")
            if not line or not line.startswith("data:"):
                continue
            if line == "data: [DONE]":
                break
            event = json.loads(line[5:].strip("\n"))
            new_text = event["text"][pos:]
            pos += len(new_text)
            yield new_text, event["meta_info"]
    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
        choices_method: ChoicesSamplingMethod,
    ) -> ChoicesDecision:
        """Choose among *choices* by scoring each continuation's prompt logprobs.

        Caches the shared prefix, requests logprobs for every
        ``prefix + choice`` string, compensates for token healing, and
        delegates the final decision to *choices_method*.

        Only greedy selection is supported, hence the near-zero temperature
        requirement below.
        """
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        obj = self._generate_http_request(s, data)
        prompt_len = obj["meta_info"]["prompt_tokens"]
        # Start two tokens early so a prefix token re-tokenized ("healed")
        # across the prefix/choice boundary still has a logprob.
        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {
                "max_new_tokens": 0,
                "temperature": 0,
            },
            "return_logprob": True,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        }
        obj = self._generate_http_request(s, data)

        # Each meta_info entry appears to hold (logprob, token_id, token_str)
        # tuples per token — TODO(review): confirm against the server schema.
        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
        normalized_prompt_logprobs = [
            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
            for r in obj
        ]

        # Remove extra token if no token healing occurred
        # (i.e. the first returned token is still a verbatim suffix of the
        # prefix text, so it belongs to the prefix, not the choice; drop it
        # and recompute the mean without its logprob).
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]

        # Compute unconditional logprobs if required
        # (scores each choice's token ids without the shared prefix).
        if choices_method.requires_unconditional_logprobs:
            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
            data = {
                "input_ids": input_ids,
                "sampling_params": {"max_new_tokens": 0},
                "return_logprob": True,
            }
            obj = self._generate_http_request(s, data)
            unconditional_token_logprobs = [
                r["meta_info"]["input_token_logprobs"] for r in obj
            ]
        else:
            unconditional_token_logprobs = None

        return choices_method(
            choices=choices,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
            input_token_logprobs=input_token_logprobs,
            output_token_logprobs=output_token_logprobs,
            unconditional_token_logprobs=unconditional_token_logprobs,
        )
    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Ask the server to concatenate the *src_rids* requests onto *dst_rid*."""
        payload = {"src_rids": src_rids, "dst_rid": dst_rid}
        response = http_request(
            f"{self.base_url}/concate_and_append_request",
            json=payload,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
    def _generate_http_request(self, s: StreamExecutor, data):
        """POST *data* (plus any image payload) to ``/generate``; return JSON."""
        self._add_images(s, data)
        response = http_request(
            f"{self.base_url}/generate",
            json=data,
            api_key=self.api_key,
            verify=self.verify,
        )
        self._assert_success(response)
        return response.json()
    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's (single) image payload to the request dict."""
        images = s.images_
        if not images:
            return
        assert len(images) == 1, "Only support one image."
        data["image_data"] = images[0][1]
    def _assert_success(self, res):
        """Raise ``RuntimeError`` with the response payload on any non-200 status."""
        if res.status_code == 200:
            return
        # Prefer the structured error body; fall back to raw text.
        try:
            detail = res.json()
        except json.JSONDecodeError:
            detail = res.text
        raise RuntimeError(detail)
def compute_normalized_prompt_logprobs(input_logprobs):
    """Return the mean per-token prompt logprob.

    Args:
        input_logprobs: Sequence of per-token tuples whose first element is
            the token's logprob (may be ``None`` for tokens without a score).

    Returns:
        The arithmetic mean of the non-``None`` logprobs.

    Raises:
        ZeroDivisionError: If no entry carries a usable logprob.

    Bug fix: the original filter ``if x[0]`` relied on truthiness, which
    discarded legitimate logprobs equal to 0.0 (probability-1 tokens) in
    addition to the intended ``None`` placeholders, skewing the mean.
    """
    values = [x[0] for x in input_logprobs if x[0] is not None]
    return sum(values) / len(values)
class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the command line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing without the frontend language.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
        # client code without installing SRT server and its dependency if they want.
        from sglang.srt.entrypoints.http_server import launch_server
        from sglang.srt.server_args import ServerArgs
        from sglang.srt.utils import is_port_available

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports: scan upward from the configured port for a free one.
        # NOTE(review): if no port below 40000 is free, the last probed port is
        # used anyway — confirm whether that is intentional.
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        # Spawn (not fork) a fresh process running the server; the write end of
        # the pipe is handed to the child so it can signal readiness.
        ctx = multiprocessing.get_context("spawn")
        proc = ctx.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        # Close the parent's copy of the write end so recv() below sees EOF
        # if the child dies before signaling.
        pipe_writer.close()
        self.pid = proc.pid

        # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
        try:
            init_state = pipe_reader.recv()
        except EOFError:
            # Child exited without writing anything.
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        """Kill the server process tree; safe to call multiple times."""
        from sglang.srt.utils import kill_process_tree

        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def start_profile(self):
        """Start the server-side profiler via the endpoint."""
        self.endpoint.start_profile()

    def stop_profile(self):
        """Stop the server-side profiler via the endpoint."""
        self.endpoint.stop_profile()

    def cache_prefix(self, prefix: str):
        """Warm the server's prefix cache with *prefix*."""
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        """Build and return the tokenizer matching the server's configuration."""
        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        """Stream generation results for *prompt*, yielding text deltas.

        Yields incremental text when the server returns "text" fields;
        otherwise yields the raw decoded event dict.
        """
        # When the server skips tokenizer init, the prompt is expected to be
        # token ids rather than text.
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        # The server streams cumulative text; `pos` tracks what was yielded.
        pos = 0

        # Generous 3-hour total timeout for long generations.
        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    # Alias so callers can use the generic "add_request" name.
    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        """Synchronous, non-streaming generation; returns the response as a JSON string."""
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        # A per-prompt lora_path list must line up one-to-one with the prompts.
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        """Call the server's /encode endpoint; returns the response as a JSON string."""
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        """Fetch the server info endpoint; raises RuntimeError on failure."""
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        # Best-effort cleanup mirroring the atexit hook.
        self.shutdown()