"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of the inference server.
SRT = SGLang Runtime.
"""
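# To launch this server from the command line (a sketch; the model path is a
# placeholder and most flags are omitted):
#
#     python -m sglang.launch_server --model-path <model-path> --port 30000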

import asyncio
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union

# Fix a bug of Python threading: make threading._register_atexit a no-op so
# hooks registered through it cannot block interpreter shutdown.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller_multi import (
    start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller_single import launch_tp_servers
from sglang.srt.managers.controller_single import (
    start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    allocate_init_ports,
    assert_pkg_version,
    configure_logger,
    enable_show_time_cost,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model,
    prepare_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
tokenizer_manager = None

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):

    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"message": message, "success": str(success)}
    if success:
        return JSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return JSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results():
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return JSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)
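
# The /v1 endpoints mirror the OpenAI API, so the official `openai` client can
# talk to them (a sketch; base_url, api_key, and model name are
# deployment-specific):
#
#     import openai
#     client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
#     reply = client.chat.completions.create(
#         model="default",
#         messages=[{"role": "user", "content": "Hello"}],
#     )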


@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models")
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """Launch an HTTP server."""
    global tokenizer_manager

    configure_logger(server_args)

    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    server_args.port, server_args.additional_ports = allocate_init_ports(
        server_args.port,
        server_args.additional_ports,
        server_args.dp_size,
    )
    ports = server_args.additional_ports
    port_args = PortArgs(
        tokenizer_port=ports[0],
        controller_port=ports[1],
        detokenizer_port=ports[2],
        nccl_ports=ports[3:],
    )
    logger.info(f"{server_args=}")

    # If the model is hosted on www.modelscope.cn, download it first.
    server_args.model_path = prepare_model(server_args.model_path)
    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

    # Launch processes for multi-node tensor parallelism
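    # For example (illustrative numbers): with tp_size=8 and nnodes=2, each
    # node runs tp_size_local=4 ranks, and node_rank=1 hosts tp ranks 4..7 on
    # its local GPUs 0..3.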
    if server_args.nnodes > 1 and server_args.node_rank != 0:
        tp_size_local = server_args.tp_size // server_args.nnodes
        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
        tp_rank_range = list(
            range(
                server_args.node_rank * tp_size_local,
                (server_args.node_rank + 1) * tp_size_local,
            )
        )
        procs = launch_tp_servers(
            gpu_ids,
            tp_rank_range,
            server_args,
            ports[3],
        )

        try:
            for p in procs:
                p.join()
        finally:
            kill_child_process(os.getpid(), including_parent=False)
            return

    # Launch processes
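    # This process hosts the HTTP app and the TokenizerManager; the controller
    # (scheduler and model workers) and the detokenizer are started below as
    # separate processes, connected through the ports allocated above.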
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    if server_args.dp_size == 1:
        start_controller_process = start_controller_process_single
    else:
        start_controller_process = start_controller_process_multi

    proc_controller = mp.Process(
        target=start_controller_process,
        args=(server_args, port_args, pipe_controller_writer),
    )
    proc_controller.start()

    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    controller_init_state = pipe_controller_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if controller_init_state != "init ok" or detoken_init_state != "init ok":
        proc_controller.kill()
        proc_detoken.kill()
        raise RuntimeError(
            "Initialization failed. "
            f"controller_init_state: {controller_init_state}, "
            f"detoken_init_state: {detoken_init_state}"
        )
    assert proc_controller.is_alive() and proc_detoken.is_alive()

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
    t.start()

    try:
        # Listen for requests
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    # Set ulimit
    set_ulimit()

    # Enable show time cost for debugging
    if server_args.show_time_cost:
        enable_show_time_cost()

    # Disable disk cache
    if server_args.disable_disk_cache:
        disable_cache()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if not server_args.disable_flashinfer:
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )


def _wait_and_warmup(server_args, pipe_finish_writer, pid):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    # Only read the model info once the server is known to be up; `res` is
    # unbound if every attempt above failed.
    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("init ok")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a Python program without
    using the command line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.dp_size,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        # `data` is a dict parsed from JSON, so test for the
                        # key; hasattr() on a dict is always False here.
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        }
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str]],
    ):
        json_data = {
            "text": prompt,
        }
        response = requests.post(
            self.url + "/encode",
            json=json_data,
        )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()