"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""
The entry point of the inference server.
SRT = SGLang Runtime.
"""

import asyncio
import atexit
import dataclasses
import json
import logging
import multiprocessing as mp
import os
import re
import tempfile
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union

import orjson
from starlette.routing import Mount

# Fix a bug in Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

import aiohttp
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from uvicorn.config import LOGGING_CONFIG

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
    run_data_parallel_controller_process,
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    UpdateWeightReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import (
    load_chat_template_for_openai_api,
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    assert_pkg_version,
    configure_logger,
    is_port_available,
    kill_child_process,
    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
    set_ulimit,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir: tempfile.TemporaryDirectory

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


app = FastAPI()
tokenizer_manager: TokenizerManager = None

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health() -> Response:
    """Check the health of the HTTP server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""
    gri = GenerateReqInput(
        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
    )
    try:
        async for _ in tokenizer_manager.generate_request(gri, request):
            break
        return Response(status_code=200)
    except Exception as e:
        logger.exception(e)
        return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": tokenizer_manager.model_path,
        "is_generation": tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_args")
async def get_server_args():
    """Get the server arguments."""
    return dataclasses.asdict(tokenizer_manager.server_args)


@app.post("/flush_cache")
async def flush_cache():
    """Flush the radix cache."""
    tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    tokenizer_manager.start_profile()
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
    """Get the memory pool size in number of tokens."""
    try:
        ret = await tokenizer_manager.get_memory_pool_size()
        return ret
    except Exception as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights in place without re-launching the server."""
    success, message = await tokenizer_manager.update_weights(obj, request)
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )


# FastAPI implicitly converts the JSON request body into this dataclass.
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in tokenizer_manager.generate_request(obj, request):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
            return ret
        except ValueError as e:
            return ORJSONResponse(
                {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
            )

app.post("/generate")(generate_request)
app.put("/generate")(generate_request)
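# Example (a sketch, not part of the server): a non-streaming call to the native
# /generate endpoint. The host/port are assumptions; the "text" and
# "sampling_params" fields mirror the warmup request sent in _wait_and_warmup.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:30000/generate",
#       json={
#           "text": "The capital city of France is",
#           "sampling_params": {"temperature": 0, "max_new_tokens": 8},
#       },
#   )
#   print(resp.json())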

async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/encode")(encode_request)
app.put("/encode")(encode_request)


async def classify_request(obj: EmbeddingReqInput, request: Request):
    """Handle a reward model request. Currently, the arguments and return values are the same as for embedding models."""
    try:
        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
        return ret
    except ValueError as e:
        return ORJSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


app.post("/classify")(classify_request)
app.put("/classify")(classify_request)


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(tokenizer_manager, raw_request)


@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(tokenizer_manager, raw_request)
    return response


@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    served_model_names = [tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, tokenizer_manager.server_args.file_storage_pth
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


def launch_engine(
    server_args: ServerArgs,
):
    """
    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
    """
    global tokenizer_manager

    # Configure global environment
    configure_logger(server_args)
    server_args.check_server_args()
    _set_envs_and_config(server_args)

    # Allocate ports for inter-process communications
    port_args = PortArgs.init_new(server_args)
    logger.info(f"{server_args=}")

    # If the model comes from www.modelscope.cn, download it first.
    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
        server_args.model_path, server_args.tokenizer_path
    )

    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
            tp_size_per_node * server_args.node_rank,
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
            )
            proc.start()
            scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

        if server_args.node_rank >= 1:
            # Other nodes do not need to run the tokenizer or detokenizer,
            # so they simply wait here.
            while True:
                pass
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()

    # Launch detokenizer process
    detoken_proc = mp.Process(
        target=run_detokenizer_process,
        args=(
            server_args,
            port_args,
        ),
    )
    detoken_proc.start()

    # Launch tokenizer process
    tokenizer_manager = TokenizerManager(server_args, port_args)
    if server_args.chat_template:
        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

    # Wait for model to finish loading
    for i in range(len(scheduler_pipe_readers)):
        scheduler_pipe_readers[i].recv()


def add_prometheus_middleware(app: FastAPI):
    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection] = None,
):
    """
    Launch SRT (SGLang Runtime) Server

    The SRT server consists of an HTTP server and the SRT engine.

    1. HTTP server: A FastAPI server that routes requests to the engine.
    2. SRT engine:
        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server and Tokenizer Manager both run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """

    launch_engine(server_args=server_args)

    # Add API key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add Prometheus middleware
    if server_args.enable_metrics:
        _set_prometheus_env()
        add_prometheus_middleware(app)

    # Send a warmup request
    t = threading.Thread(
        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
    )
    t.start()

    try:
        # Listen for HTTP requests
        LOGGING_CONFIG["formatters"]["default"][
            "fmt"
        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        LOGGING_CONFIG["formatters"]["access"][
            "fmt"
        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        t.join()


def _set_prometheus_env():
    # Set the prometheus multiprocess directory.
    # sglang uses prometheus multiprocess mode.
    # We need to set this before importing prometheus_client.
    # https://prometheus.github.io/client_python/multiprocess/
    global prometheus_multiproc_dir
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.debug("User-set PROMETHEUS_MULTIPROC_DIR detected.")
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
        )
    else:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def _set_envs_and_config(server_args: ServerArgs):
    # Set global environments
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"

    # Set ulimit
    set_ulimit()

    # Fix triton bugs
    if server_args.tp_size * server_args.dp_size > 1:
        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
        maybe_set_triton_cache_manager()

    # Check flashinfer version
    if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer",
            "0.1.6",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
        )

    mp.set_start_method("spawn", force=True)


def _wait_and_warmup(server_args, pipe_finish_writer):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()
            pass

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_child_process(include_self=True)
        return

    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")


class Runtime:
    """
    A wrapper for the server.
    This is used for launching the server in a Python program without
    using the command line interface.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # Register an implicit shutdown when the Python program exits, so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(10000, 40000):
            if is_port_available(port):
                break
            port += 1
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store the pid instead of the process object to avoid issues during __del__.
        self.pid = None
        pipe_reader, pipe_writer = mp.Pipe(duplex=False)

        proc = mp.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_child_process(self.pid, include_self=True)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    def __del__(self):
        self.shutdown()


STREAM_END_SYMBOL = b"data: [DONE]"
STREAM_CHUNK_START_SYMBOL = b"data:"


class Engine:
    """
    SRT Engine without an HTTP server layer.

    This class provides a direct inference engine without the need for an HTTP
    server. It is designed for use cases where launching the HTTP server adds
    unnecessary complexity or overhead.
    """

    def __init__(self, *args, **kwargs):

        # Register an implicit shutdown when the Python program exits, so users
        # don't have to call .shutdown() explicitly.
        atexit.register(self.shutdown)

        # The runtime server uses its default log level; the offline engine is
        # typically used in scripts, so default to "error" to keep output quiet.
        if "log_level" not in kwargs:
            kwargs["log_level"] = "error"

        server_args = ServerArgs(*args, **kwargs)
        launch_engine(server_args=server_args)

    def generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        # Get the current event loop and run the request to completion.
        loop = asyncio.get_event_loop()
        ret = loop.run_until_complete(generate_request(obj, None))

        if stream is True:

            def generator_wrapper():
                offset = 0
                loop = asyncio.get_event_loop()
                generator = ret.body_iterator
                while True:
                    chunk = loop.run_until_complete(generator.__anext__())

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            # We cannot yield directly inside generate(); a yield there would turn
            # generate() itself into a generator. Instead, we wrap the streaming
            # logic in a nested generator and return it.
            return generator_wrapper()
        else:
            return ret

    async def async_generate(
        self,
        # The input prompt. It can be a single prompt or a batch of prompts.
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Dict] = None,
        # The token ids for text; one can either specify text or input_ids.
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
        stream: bool = False,
    ):
        obj = GenerateReqInput(
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            lora_path=lora_path,
            stream=stream,
        )

        ret = await generate_request(obj, None)

        if stream is True:
            generator = ret.body_iterator

            async def generator_wrapper():
                offset = 0

                while True:
                    chunk = await generator.__anext__()

                    if chunk.startswith(STREAM_END_SYMBOL):
                        break
                    else:
                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
                        data["text"] = data["text"][offset:]
                        offset += len(data["text"])
                        yield data

            return generator_wrapper()
        else:
            return ret

    def shutdown(self):
        kill_child_process()

    def get_tokenizer(self):
        global tokenizer_manager

        if tokenizer_manager is None:
            raise ReferenceError("Tokenizer Manager is not initialized.")
        else:
            return tokenizer_manager.tokenizer

    # TODO (ByronHsu): encode