"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of inference server.
SRT = SGLang Runtime.
"""

import asyncio
import atexit
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import random
import threading
import time
from http import HTTPStatus
from typing import Dict, List, Optional, Union

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    RewardReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    assert_pkg_version,
    configure_logger,
    is_port_available,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
tokenizer_manager = None

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    """Get the server arguments."""
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.get("/flush_cache")
async def flush_cache():
    """Flush the radix cache."""
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    tokenizer_manager.start_profile()
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights in place without re-launching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"success": success, "message": message}
    if success:
        return JSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return JSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


# FastAPI implicitly converts the request JSON into the `obj` dataclass
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results():
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return JSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
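
# Example request to the /generate endpoint (an illustrative sketch; the host, port,
# and prompt are placeholders, not values defined in this file):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:30000/generate",
#       json={
#           "text": "The capital of France is",
#           "sampling_params": {"max_new_tokens": 16, "temperature": 0},
#       },
#   )
#   print(resp.json())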


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)
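
# Example request to the /encode endpoint (an illustrative sketch; host, port, and
# the input text are placeholders):
#
#   resp = requests.post(
#       "http://127.0.0.1:30000/encode",
#       json={"text": "Hello world"},
#   )
#   print(resp.json())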


async def judge_request(obj: RewardReqInput, request: Request):
    """Handle a reward model request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/judge")(judge_request)
app.put("/judge")(judge_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)
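
# Example OpenAI-compatible chat request (an illustrative sketch; the address and
# model name are placeholders, and the body follows the standard OpenAI chat schema):
#
#   resp = requests.post(
#       "http://127.0.0.1:30000/v1/chat/completions",
#       json={
#           "model": "default",
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#   )
#   print(resp.json())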


@app.post("/v1/embeddings")
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models")
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
    """

    global tokenizer_manager

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If using model from www.modelscope.cn, first download the model.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        if server_args.node_rank >= 1:
            # For other nodes, they do not need to run tokenizer or detokenizer,
            # so they can just wait here.
            while True:
                pass
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for model to finish loading
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """

    launch_engine(server_args=server_args)

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
    )
    t.start()

    try:
        # Listen for HTTP requests
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    mp.set_start_method("spawn", force=True)


def _wait_and_warmup(server_args, pipe_finish_writer, pid):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(pid, including_parent=False)
        return

    print(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a python program without
    using the commond line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Before the Python program terminates, call shutdown() implicitly so users
        # do not have to call it explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(10000, 40000):
            if is_port_available(port):
                break
            port += 1
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __del__
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            # embedding
            json_data = {
                "text": prompt,
            }
            response = requests.post(
                self.url + "/encode",
                json=json_data,
            )
        else:
            # reward
            json_data = {
                "conv": prompt,
            }
            response = requests.post(
                self.url + "/judge",
                json=json_data,
            )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
    launching the HTTP server adds unnecessary complexity or overhead.
    """

    def __init__(self, *args, **kwargs):

        # Before the Python program terminates, call shutdown() implicitly so users
        # do not have to call it explicitly.
        atexit.register(self.shutdown)

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        # TODO (ByronHsu): refactor to reduce the duplicated code

        obj = GenerateReqInput(
            text=prompt,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        # get the current event loop
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:
            STREAM_END_SYMBOL = "data: [DONE]"
            STREAM_CHUNK_START_SYMBOL = "data:"

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
            # however, it allows to wrap the generator as a subfunction and return
            return generator_wrapper()
        else:
            return ret

    async def async_generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            STREAM_END_SYMBOL = "data: [DONE]"
            STREAM_CHUNK_START_SYMBOL = "data:"

            generator = ret.body_iterator

            async def generator_wrapper():

                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        kill_child_process(os.getpid(), including_parent=False)

    def get_tokenizer(self):
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    # TODO (ByronHsu): encode