"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of inference server.
SRT = SGLang Runtime.
"""

import asyncio
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union

# Work around a Python threading shutdown issue: make threading's atexit
# registration a no-op so that registered thread-exit hooks cannot block
# interpreter shutdown.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller_multi import (
    start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller_single import launch_tp_servers
from sglang.srt.managers.controller_single import (
    start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    allocate_init_ports,
    assert_pkg_version,
    configure_logger,
    enable_show_time_cost,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model,
    prepare_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
tokenizer_manager = None

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"message": message, "success": str(success)}
    status_code = HTTPStatus.OK if success else HTTPStatus.BAD_REQUEST
    return JSONResponse(content, status_code=status_code)
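
# Example /update_weights call (a sketch, not part of this module; the request
# fields follow UpdateWeightReqInput as an assumption, and the checkpoint path
# is purely illustrative):
#
#     import requests
#     requests.post(
#         "http://localhost:30000/update_weights",
#         json={"model_path": "/path/to/updated/checkpoint"},
#     )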


async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results():
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return JSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
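
# Example non-streaming /generate call (a sketch, not part of this module;
# the values are illustrative and assume the default port, 30000):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:30000/generate",
#         json={
#             "text": "The capital city of France is",
#             "sampling_params": {"temperature": 0, "max_new_tokens": 16},
#         },
#     )
#     print(resp.json())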


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
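
# Example /encode call for an embedding model (a sketch; assumes the default
# port, 30000):
#
#     resp = requests.post(
#         "http://localhost:30000/encode",
#         json={"text": "Hello world"},
#     )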


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models")
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """Launch an HTTP server."""
    global tokenizer_manager

    configure_logger(server_args)

    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communication
    server_args.port, server_args.additional_ports = allocate_init_ports(
        server_args.port,
        server_args.additional_ports,
        server_args.dp_size,
    )
    ports = server_args.additional_ports
    port_args = PortArgs(
        tokenizer_port=ports[0],
        controller_port=ports[1],
        detokenizer_port=ports[2],
        nccl_ports=ports[3:],
    )
    logger.info(f"{server_args=}")

    # If the model comes from www.modelscope.cn, download it first.
    server_args.model_path = prepare_model(server_args.model_path)
    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

    # Launch processes for multi-node tensor parallelism
    if server_args.nnodes > 1 and server_args.node_rank != 0:
        tp_size_local = server_args.tp_size // server_args.nnodes
        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
        tp_rank_range = list(
            range(
                server_args.node_rank * tp_size_local,
                (server_args.node_rank + 1) * tp_size_local,
            )
        )
        procs = launch_tp_servers(
            gpu_ids,
            tp_rank_range,
            server_args,
            ports[3],
        )

        try:
            for p in procs:
                p.join()
        finally:
            kill_child_process(os.getpid(), including_parent=False)
            return

    # Launch processes
    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)

    if server_args.dp_size == 1:
        start_controller_process = start_controller_process_single
    else:
        start_controller_process = start_controller_process_multi
    proc_controller = mp.Process(
        target=start_controller_process,
        args=(server_args, port_args, pipe_controller_writer),
    )
    proc_controller.start()

    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for the model to finish loading
    controller_init_state = pipe_controller_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if controller_init_state != "init ok" or detoken_init_state != "init ok":
        proc_controller.kill()
        proc_detoken.kill()
        raise RuntimeError(
            "Initialization failed. "
            f"controller_init_state: {controller_init_state}, "
            f"detoken_init_state: {detoken_init_state}"
        )
    assert proc_controller.is_alive() and proc_detoken.is_alive()

    # Add API key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
    t.start()

    try:
        # Listen for requests
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    # Set ulimit
    set_ulimit()

    # Enable show time cost for debugging
    if server_args.show_time_cost:
        enable_show_time_cost()

    # Disable disk cache
    if server_args.disable_disk_cache:
        disable_cache()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )


def _wait_and_warmup(server_args, pipe_finish_writer, pid):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("init ok")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a Python program without
    using the command-line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.dp_size,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            # Yield only the newly generated suffix of the text.
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str]],
    ):
        json_data = {
            "text": prompt,
        }
        response = requests.post(
            self.url + "/encode",
            json=json_data,
        )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()
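

# Example in-process usage of Runtime (a sketch; the model path is
# illustrative):
#
#     runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
#     print(runtime.generate("The capital city of France is"))
#     runtime.shutdown()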