"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of inference server.
SRT = SGLang Runtime.
"""

import asyncio
import atexit
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union

import orjson

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from uvicorn.config import LOGGING_CONFIG

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    assert_pkg_version,
    configure_logger,
    is_port_available,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
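# Global TokenizerManager instance; it is created in launch_engine() below.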
tokenizer_manager: TokenizerManager = None

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    """Get the server arguments."""
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.post("/flush_cache")
async def flush_cache():
    """Flush the radix cache."""
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    tokenizer_manager.start_profile()
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
    """Get the memory pool size in number of tokens"""
    try:
        ret = await tokenizer_manager.get_memory_pool_size()

        return ret
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights in place without re-launching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


# FastAPI implicitly converts the JSON in the request body to `obj` (a dataclass).
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return ORJSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


async def judge_request(obj: EmbeddingReqInput, request: Request):
    """Handle a reward model request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/judge")(judge_request)
app.put("/judge")(judge_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)
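
# Example: calling the OpenAI-compatible chat endpoint (a sketch; assumes the
# `openai` client package and the default port 30000):
#
#     from openai import OpenAI
#     client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
#     client.chat.completions.create(
#         model="default",
#         messages=[{"role": "user", "content": "List three prime numbers."}],
#     )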


@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
    """

    global tokenizer_manager

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If using model from www.modelscope.cn, first download the model.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        if server_args.node_rank >= 1:
            # For other nodes, they do not need to run tokenizer or detokenizer,
            # so they can just wait here.
            while True:
                pass
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for model to finish loading
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication (IPC) is done via the ZMQ library, with each process using a different port.
    """

    launch_engine(server_args=server_args)

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
    )
    t.start()

    try:
        # Listen for HTTP requests
        LOGGING_CONFIG["formatters"]["default"][
            "fmt"
        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        LOGGING_CONFIG["formatters"]["access"][
            "fmt"
        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()
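

# Example: launching the server programmatically (a sketch; ServerArgs accepts the
# same options as the command-line launcher, and the model path is illustrative):
#
#     if __name__ == "__main__":
#         launch_server(ServerArgs(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct"))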


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    mp.set_start_method("spawn", force=True)


def _wait_and_warmup(server_args, pipe_finish_writer):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    model_info = res.json()
    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a Python program without
    using the command line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Register an implicit shutdown at interpreter exit so users don't have
        # to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(10000, 40000):
            if is_port_available(port):
                break
            port += 1
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the proc to avoid issues during __del__.
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid, include_self=True)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            # embedding
            json_data = {
                "text": prompt,
            }
            response = requests.post(
                self.url + "/encode",
                json=json_data,
            )
        else:
            # reward
            json_data = {
                "conv": prompt,
            }
            response = requests.post(
                self.url + "/judge",
                json=json_data,
            )
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()


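# Markers used to split the server-sent-event (SSE) stream produced by /generate.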
STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP server.
    It is designed for use cases where launching the HTTP server adds unnecessary
    complexity or overhead.
    """

    def __init__(self, *args, **kwargs):

        # Register an implicit shutdown at interpreter exit so users don't have
        # to call .shutdown() explicitly.
        atexit.register(self.shutdown)
        # The runtime server uses its default log level, but the offline engine is
        # typically used in scripts, so default to "error" to keep output quiet.
        if "log_level" not in kwargs:
            kwargs["log_level"] = "error"

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        # get the current event loop
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            # We cannot yield directly inside generate(): a yield anywhere in its body
            # would turn generate() into a generator and break the non-streaming return
            # path. Instead, wrap the streaming logic in an inner generator and return it.
            return generator_wrapper()
        else:
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            generator = ret.body_iterator

            async def generator_wrapper():

                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        kill_child_process()

    def get_tokenizer(self):
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    # TODO (ByronHsu): encode