"vscode:/vscode.git/clone" did not exist on "8bd9aa11bce4ec8469fd2e0ee34fdc5658faba2d"
server.py 19.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of inference server.
SRT = SGLang Runtime.
"""

import asyncio
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union

# Work around a Python threading bug: stub out threading._register_atexit so
# that background threads do not block or raise at interpreter shutdown.
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller_multi import (
    start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller_single import launch_tp_servers
from sglang.srt.managers.controller_single import (
    start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    allocate_init_ports,
    assert_pkg_version,
    enable_show_time_cost,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model,
    prepare_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
tokenizer_manager = None


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the model weights in place without relaunching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"message": message, "success": str(success)}
    if success:
        return JSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return JSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results():
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return JSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)
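# The /v1/* routes follow the OpenAI API, so the official client can be pointed
# at this server. A sketch (hypothetical host/port; "default" stands in for the
# served model name):
#     import openai
#     client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
#     client.chat.completions.create(
#         model="default",
#         messages=[{"role": "user", "content": "Hello!"}],
#     )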


@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models")
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_server(
    server_args: ServerArgs,
    model_overide_args: Optional[dict] = None,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """Launch an HTTP server."""
    global tokenizer_manager

    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format="%(message)s",
    )

    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports
    server_args.port, server_args.additional_ports = allocate_init_ports(
        server_args.port,
        server_args.additional_ports,
        server_args.dp_size,
    )
    ports = server_args.additional_ports
    port_args = PortArgs(
        tokenizer_port=ports[0],
        controller_port=ports[1],
        detokenizer_port=ports[2],
        nccl_ports=ports[3:],
    )
    logger.info(f"{server_args=}")

    # If the model comes from www.modelscope.cn, download it first.
    server_args.model_path = prepare_model(server_args.model_path)
    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

    # Launch processes for multi-node tensor parallelism
    if server_args.nnodes > 1 and server_args.node_rank != 0:
        tp_size_local = server_args.tp_size // server_args.nnodes
        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
        tp_rank_range = list(
            range(
                server_args.node_rank * tp_size_local,
                (server_args.node_rank + 1) * tp_size_local,
            )
        )
        procs = launch_tp_servers(
            gpu_ids,
            tp_rank_range,
            server_args,
            ports[3],
            model_overide_args,
        )

        try:
            for p in procs:
                p.join()
        finally:
            kill_child_process(os.getpid(), including_parent=False)
            return

    # Create the tokenizer manager and launch the controller and detokenizer processes
    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

    if server_args.dp_size == 1:
        start_process = start_controller_process_single
    else:
        start_process = start_controller_process_multi
    proc_controller = mp.Process(
        target=start_process,
        args=(server_args, port_args, pipe_controller_writer, model_overide_args),
    )
    proc_controller.start()
    proc_detoken = mp.Process(
        target=start_detokenizer_process,
        args=(
            server_args,
            port_args,
            pipe_detoken_writer,
        ),
    )
    proc_detoken.start()

    # Wait for the model to finish loading
    controller_init_state = pipe_controller_reader.recv()
    detoken_init_state = pipe_detoken_reader.recv()

    if controller_init_state != "init ok" or detoken_init_state != "init ok":
        proc_controller.kill()
        proc_detoken.kill()
        raise RuntimeError(
            "Initialization failed. "
            f"controller_init_state: {controller_init_state}, "
            f"detoken_init_state: {detoken_init_state}"
        )
    assert proc_controller.is_alive() and proc_detoken.is_alive()

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Start a thread that waits for the server to come up and sends a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
    t.start()

    try:
        # Listen for requests
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environment variables
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

    # Set ulimit
    set_ulimit()

    # Enable show time cost for debugging
    if server_args.show_time_cost:
        enable_show_time_cost()

    # Disable disk cache
    if server_args.disable_disk_cache:
        disable_cache()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if not server_args.disable_flashinfer:
        assert_pkg_version(
            "flashinfer",
            "0.1.5",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )


def _wait_and_warmup(server_args, pipe_finish_writer, pid):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("init ok")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a python program without
    using the commond line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        model_overide_args: Optional[dict] = None,
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Pre-allocate ports
        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
            self.server_args.port,
            self.server_args.additional_ports,
            self.server_args.dp_size,
        )

        self.url = self.server_args.url()
        self.generate_url = (
            f"http://{self.server_args.host}:{self.server_args.port}/generate"
        )

        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)
        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, model_overide_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "init ok":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        # `data` is a dict parsed from the SSE payload, so test
                        # key membership with `in` (hasattr on a dict is always False).
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
        }
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: str,
    ):
        json_data = {
            "text": prompt,
        }
        response = requests.post(
            self.url + "/encode",
            json=json_data,
        )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()