"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of inference server.
SRT = SGLang Runtime.
"""
import asyncio
import atexit
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import re
import tempfile
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union

import orjson
from starlette.routing import Mount

# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from uvicorn.config import LOGGING_CONFIG

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    assert_pkg_version,
    configure_logger,
    delete_directory,
    is_port_available,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir: tempfile.TemporaryDirectory

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

tokenizer_manager: TokenizerManager = None

##### Native API endpoints #####


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    """Get the server arguments."""
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.post("/flush_cache")
async def flush_cache():
    """Flush the radix cache."""
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    tokenizer_manager.start_profile()
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
    """Get the memory pool size in number of tokens"""
    try:
        ret = await tokenizer_manager.get_memory_pool_size()
        return ret
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights in place without re-launching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


# FastAPI implicitly converts the request JSON into `obj` (a dataclass).
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return ORJSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )


app.post("/generate")(generate_request)
app.put("/generate")(generate_request)


async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


async def classify_request(obj: EmbeddingReqInput, request: Request):
    """Handle a reward model request. Currently, the arguments and return values are the same as for embedding models."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/classify")(classify_request)
app.put("/classify")(classify_request)


##### OpenAI-compatible API endpoints #####
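
# Example usage with the official `openai` client (illustrative sketch; assumes the
# openai>=1.0 Python package and the default host/port; the model id should match the
# served model name reported by /v1/models, "default" here is a placeholder):
#
#   import openai
#   client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
#   resp = client.chat.completions.create(
#       model="default",
#       messages=[{"role": "user", "content": "What is the capital of France?"}],
#   )
#   print(resp.choices[0].message.content)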


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
    """

    global tokenizer_manager

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If using a model from www.modelscope.cn, first download the model.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        if server_args.node_rank >= 1:
            # For other nodes, they do not need to run tokenizer or detokenizer,
            # so they can just wait here.
            while True:
                pass
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for model to finish loading
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """

    if server_args.enable_metrics:
        _set_prometheus_env()

    launch_engine(server_args=server_args)

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add Prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
    )
    t.start()

    try:
        # Listen for HTTP requests
        LOGGING_CONFIG["formatters"]["default"][
            "fmt"
        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        LOGGING_CONFIG["formatters"]["access"][
            "fmt"
        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def add_prometheus_middleware(app: FastAPI):
    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def _set_prometheus_env():
    # Set prometheus multiprocess directory
    # sglang uses prometheus multiprocess mode
    # we need to set this before importing prometheus_client
    # https://prometheus.github.io/client_python/multiprocess/
    global prometheus_multiproc_dir
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.")
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
        )
    else:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    mp.set_start_method("spawn", force=True)


def _wait_and_warmup(server_args, pipe_finish_writer):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a python program without
    using the command line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Before the Python program terminates, call shutdown() implicitly so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(10000, 40000):
            if is_port_available(port):
                break
            port += 1
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the proc to fix some issues during __del__
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid, include_self=True)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()


STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP server. It is designed
    for use cases where launching the HTTP server adds unnecessary complexity or overhead.
    """

    def __init__(self, *args, **kwargs):

        # Before the Python program terminates, call shutdown() implicitly so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # The runtime server's default log level is more verbose; the offline engine
        # runs in scripts, so default to "error" unless the caller overrides it.
        if "log_level" not in kwargs:
            kwargs["log_level"] = "error"

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        # get the current event loop
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            # We cannot yield directly inside generate(): a yield anywhere in its body
            # would turn the whole function into a generator, so the non-streaming
            # branch could no longer return a plain result. Wrapping the yield in an
            # inner generator function and returning it avoids this.
            return generator_wrapper()
        else:
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            generator = ret.body_iterator

            async def generator_wrapper():

                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        kill_child_process()

    def get_tokenizer(self):
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    # TODO (ByronHsu): encode