# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The entry point of the inference server. (SRT = SGLang Runtime)

This file implements HTTP APIs for the inference engine via fastapi.
"""

import asyncio
import dataclasses
import logging
import multiprocessing as multiprocessing
import os
import threading
import time
from http import HTTPStatus
from typing import AsyncIterator, Callable, Dict, Optional

# Fix a bug in Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

from contextlib import asynccontextmanager

import numpy as np
import orjson
import requests
import uvicorn
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.srt.entrypoints.engine import _launch_subprocesses
from sglang.srt.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    OpenSessionReqInput,
    ParseFunctionCallReq,
    ProfileReqInput,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
    SeparateReasoningReqInput,
    SetInternalStateReq,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    VertexGenerateReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
from sglang.srt.openai_api.adapter import (
    v1_batches,
    v1_cancel_batch,
    v1_chat_completions,
    v1_completions,
    v1_delete_file,
    v1_embeddings,
    v1_files_create,
    v1_retrieve_batch,
    v1_retrieve_file,
    v1_retrieve_file_content,
)
from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
    add_api_key_middleware,
    add_prometheus_middleware,
    delete_directory,
    kill_process_tree,
    set_uvicorn_logging_configs,
)
from sglang.srt.warmup import execute_warmups
from sglang.utils import get_exception_traceback
from sglang.version import __version__

logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


# Store global states
@dataclasses.dataclass
class _GlobalState:
    tokenizer_manager: TokenizerManager
    scheduler_info: Dict


_global_state: Optional[_GlobalState] = None


def set_global_state(global_state: _GlobalState):
    global _global_state
    _global_state = global_state


@asynccontextmanager
async def lifespan(fast_api_app: FastAPI):
    server_args: ServerArgs = fast_api_app.server_args
    if server_args.warmups is not None:
        await execute_warmups(
            server_args.warmups.split(","), _global_state.tokenizer_manager
        )
        logger.info("Warmup ended")

    warmup_thread = getattr(fast_api_app, "warmup_thread", None)
    if warmup_thread is not None:
        warmup_thread.start()
    yield


# Fast API
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))


##### Native API endpoints #####


@app.get("/health")
async def health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
    """Check the health of the inference server by generating one token."""

    sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
    rid = f"HEALTH_CHECK_{time.time()}"

    if _global_state.tokenizer_manager.is_image_gen:
        raise NotImplementedError()
    elif _global_state.tokenizer_manager.is_generation:
        gri = GenerateReqInput(
            rid=rid,
            input_ids=[0],
            sampling_params=sampling_params,
            log_metrics=False,
        )
    else:
        gri = EmbeddingReqInput(
            rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
        )

    async def gen():
        async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
            break

    tic = time.time()
    task = asyncio.create_task(gen())
    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
        await asyncio.sleep(1)
        if _global_state.tokenizer_manager.last_receive_tstamp > tic:
            task.cancel()
            _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
            return Response(status_code=200)

    task.cancel()
    tic_time = time.strftime("%H:%M:%S", time.localtime(tic))
    last_receive_time = time.strftime(
        "%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp)
    )
    logger.error(
        f"Health check failed. Server couldn't get a response from detokenizer for last "
        f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. "
        f"last_heartbeat time: {last_receive_time}"
    )
    _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
    return Response(status_code=503)
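

# Illustrative client-side probe for the health endpoints above (not executed by this
# module; the host/port are placeholders for whatever ServerArgs was launched with):
#
#   import requests
#   assert requests.get("http://127.0.0.1:30000/health").status_code == 200
#   # /health_generate schedules a one-token generation and returns 503 if no response
#   # arrives within SGLANG_HEALTH_CHECK_TIMEOUT (default 20) seconds.
#   print(requests.get("http://127.0.0.1:30000/health_generate").status_code)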


@app.get("/get_model_info")
async def get_model_info():
    """Get the model information."""
    result = {
        "model_path": _global_state.tokenizer_manager.model_path,
        "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
        "is_generation": _global_state.tokenizer_manager.is_generation,
    }
    return result


@app.get("/get_server_info")
async def get_server_info():
    internal_states = await _global_state.tokenizer_manager.get_internal_state()
    return {
        **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
        **_global_state.scheduler_info,
        **internal_states,
        "version": __version__,
    }


@app.api_route("/set_internal_state", methods=["POST", "PUT"])
async def set_internal_state(obj: SetInternalStateReq, request: Request):
    res = await _global_state.tokenizer_manager.set_internal_state(obj)
    return res


# fastapi implicitly converts json in the request to obj (dataclass)
@app.api_route("/generate", methods=["POST", "PUT"])
async def generate_request(obj: GenerateReqInput, request: Request):
    """Handle a generate request."""
    if obj.stream:

        async def stream_results() -> AsyncIterator[bytes]:
            try:
                async for out in _global_state.tokenizer_manager.generate_request(
                    obj, request
                ):
                    yield b"data: " + orjson.dumps(
                        out, option=orjson.OPT_NON_STR_KEYS
                    ) + b"\n\n"
            except ValueError as e:
                out = {"error": {"message": str(e)}}
                logger.error(f"Error: {e}")
                yield b"data: " + orjson.dumps(
                    out, option=orjson.OPT_NON_STR_KEYS
                ) + b"\n\n"
            yield b"data: [DONE]\n\n"

        return StreamingResponse(
            stream_results(),
            media_type="text/event-stream",
            background=_global_state.tokenizer_manager.create_abort_task(obj),
        )
    else:
        try:
            ret = await _global_state.tokenizer_manager.generate_request(
                obj, request
            ).__anext__()
            return ret
        except ValueError as e:
            logger.error(f"Error: {e}")
            return _create_error_response(e)
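

# Illustrative /generate request (field values are placeholders, not server defaults):
#
#   import requests
#   resp = requests.post(
#       "http://127.0.0.1:30000/generate",
#       json={
#           "text": "The capital of France is",
#           "sampling_params": {"temperature": 0, "max_new_tokens": 16},
#           "stream": False,
#       },
#   )
#   print(resp.json())
#
# With "stream": True the endpoint returns a text/event-stream of `data: {...}` chunks
# terminated by `data: [DONE]`, matching stream_results() above.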


@app.api_route("/encode", methods=["POST", "PUT"])
async def encode_request(obj: EmbeddingReqInput, request: Request):
    """Handle an embedding request."""
    try:
        ret = await _global_state.tokenizer_manager.generate_request(
            obj, request
        ).__anext__()
        return ret
    except ValueError as e:
        return _create_error_response(e)


@app.api_route("/classify", methods=["POST", "PUT"])
async def classify_request(obj: EmbeddingReqInput, request: Request):
    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
    try:
        ret = await _global_state.tokenizer_manager.generate_request(
            obj, request
        ).__anext__()
        return ret
    except ValueError as e:
        return _create_error_response(e)


@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
    """Flush the radix cache."""
    _global_state.tokenizer_manager.flush_cache()
    return Response(
        content="Cache flushed.\nPlease check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n",
        status_code=200,
    )


@app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async(obj: Optional[ProfileReqInput] = None):
    """Start profiling."""
    if obj is None:
        obj = ProfileReqInput()

    await _global_state.tokenizer_manager.start_profile(
        obj.output_dir, obj.num_steps, obj.activities
    )
    return Response(
        content="Start profiling.\n",
        status_code=200,
    )


@app.api_route("/stop_profile", methods=["GET", "POST"])
async def stop_profile_async():
    """Stop profiling."""
    _global_state.tokenizer_manager.stop_profile()
    return Response(
        content="Stop profiling. This will take some time.\n",
        status_code=200,
    )


@app.post("/update_weights_from_disk")
async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
    """Update the weights from disk in place without re-launching the server."""
    success, message, num_paused_requests = (
        await _global_state.tokenizer_manager.update_weights_from_disk(obj, request)
    )
    content = {
        "success": success,
        "message": message,
        "num_paused_requests": num_paused_requests,
    }
    if success:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.OK,
        )
    else:
        return ORJSONResponse(
            content,
            status_code=HTTPStatus.BAD_REQUEST,
        )
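

# Illustrative /update_weights_from_disk request body (the path is a placeholder;
# field names follow UpdateWeightFromDiskReqInput):
#
#   {"model_path": "/path/to/updated/checkpoint"}
#
# On success the server responds with HTTP 200 and {"success": true, ...}; otherwise 400.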


@app.post("/init_weights_update_group")
async def init_weights_update_group(
    obj: InitWeightsUpdateGroupReqInput, request: Request
):
    """Initialize the parameter update group."""
    success, message = await _global_state.tokenizer_manager.init_weights_update_group(
        obj, request
    )
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(content, status_code=200)
    else:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/update_weights_from_distributed")
async def update_weights_from_distributed(
    obj: UpdateWeightsFromDistributedReqInput, request: Request
):
    """Update model parameter from distributed online."""
    success, message = (
        await _global_state.tokenizer_manager.update_weights_from_distributed(
            obj, request
        )
    )
    content = {"success": success, "message": message}
    if success:
        return ORJSONResponse(content, status_code=200)
    else:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
    """Get model parameter by name."""
    try:
        ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
        if ret is None:
            return _create_error_response("Get parameter by name failed")
        else:
            return ORJSONResponse(ret, status_code=200)
    except Exception as e:
        return _create_error_response(e)


@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
async def release_memory_occupation(
    obj: ReleaseMemoryOccupationReqInput, request: Request
):
    """Release GPU memory occupation temporarily."""
    try:
        await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
    except Exception as e:
        return _create_error_response(e)


@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
async def resume_memory_occupation(
    obj: ResumeMemoryOccupationReqInput, request: Request
):
    """Resume GPU memory occupation."""
    try:
        await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
    except Exception as e:
        return _create_error_response(e)


@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
    """Open a session, and return its unique session id."""
    try:
        session_id = await _global_state.tokenizer_manager.open_session(obj, request)
        if session_id is None:
            raise Exception(
                "Failed to open the session. Check if a session with the same id is still open."
            )
        return session_id
    except Exception as e:
        return _create_error_response(e)


@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
    """Close the session."""
    try:
        await _global_state.tokenizer_manager.close_session(obj, request)
        return Response(status_code=200)
    except Exception as e:
        return _create_error_response(e)


@app.api_route("/configure_logging", methods=["GET", "POST"])
async def configure_logging(obj: ConfigureLoggingReq, request: Request):
    """Configure the request logging options."""
    _global_state.tokenizer_manager.configure_logging(obj)
    return Response(status_code=200)


@app.post("/parse_function_call")
async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
    """
    A native API endpoint to parse function calls from text.
    """
    # 1) Initialize the parser based on the request body
    parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser)

    # 2) Call the non-stream parsing method (non-stream)
    normal_text, calls = parser.parse_non_stream(obj.text)

    # 3) Organize the response content
    response_data = {
        "normal_text": normal_text,
        "calls": [
            call.model_dump() for call in calls
        ],  # Convert pydantic objects to dictionaries
    }

    return ORJSONResponse(content=response_data, status_code=200)
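

# Illustrative /parse_function_call request body (values are placeholders; the accepted
# tool_call_parser names are defined by FunctionCallParser, not by this file):
#
#   {
#       "text": "<model output that may contain tool calls>",
#       "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
#       "tool_call_parser": "llama3"
#   }
#
# The response carries the remaining "normal_text" and a "calls" list of parsed tool calls.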


@app.post("/separate_reasoning")
async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request):
    """
    A native API endpoint to separate reasoning content from text.
    """
    # 1) Initialize the parser based on the request body
    parser = ReasoningParser(model_type=obj.reasoning_parser)

    # 2) Call the non-stream parsing method (non-stream)
    reasoning_text, normal_text = parser.parse_non_stream(obj.text)

    # 3) Organize the response content
    response_data = {
        "reasoning_text": reasoning_text,
        "text": normal_text,
    }

    return ORJSONResponse(content=response_data, status_code=200)
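

# Illustrative /separate_reasoning request body (the reasoning_parser value is a
# placeholder; accepted names are defined by ReasoningParser):
#
#   {
#       "text": "<think>intermediate reasoning...</think>final answer",
#       "reasoning_parser": "deepseek-r1"
#   }
#
# The response splits the input into {"reasoning_text": ..., "text": ...}.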


##### OpenAI-compatible API endpoints #####


@app.post("/v1/completions")
async def openai_v1_completions(raw_request: Request):
    return await v1_completions(_global_state.tokenizer_manager, raw_request)


@app.post("/v1/chat/completions")
async def openai_v1_chat_completions(raw_request: Request):
    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)


@app.post("/v1/embeddings", response_class=ORJSONResponse)
async def openai_v1_embeddings(raw_request: Request):
    response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
    return response


@app.get("/v1/models", response_class=ORJSONResponse)
def available_models():
    """Show available models."""
    served_model_names = [_global_state.tokenizer_manager.served_model_name]
    model_cards = []
    for served_model_name in served_model_names:
        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
    return ModelList(data=model_cards)


@app.post("/v1/files")
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
    return await v1_files_create(
        file, purpose, _global_state.tokenizer_manager.server_args.file_storage_path
    )


@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/delete
    return await v1_delete_file(file_id)


@app.post("/v1/batches")
async def openai_v1_batches(raw_request: Request):
    return await v1_batches(_global_state.tokenizer_manager, raw_request)


@app.post("/v1/batches/{batch_id}/cancel")
async def cancel_batches(batch_id: str):
    # https://platform.openai.com/docs/api-reference/batch/cancel
    return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)


@app.get("/v1/batches/{batch_id}")
async def retrieve_batch(batch_id: str):
    return await v1_retrieve_batch(batch_id)


@app.get("/v1/files/{file_id}")
async def retrieve_file(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve
    return await v1_retrieve_file(file_id)


@app.get("/v1/files/{file_id}/content")
async def retrieve_file_content(file_id: str):
    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
    return await v1_retrieve_file_content(file_id)


## SageMaker API
@app.get("/ping")
async def sagemaker_health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.post("/invocations")
async def sagemaker_chat_completions(raw_request: Request):
    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)


## Vertex AI API
@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate"))
async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
    if not vertex_req.instances:
        return []
    inputs = {}
    for input_key in ("text", "input_ids", "input_embeds"):
        if vertex_req.instances[0].get(input_key):
            inputs[input_key] = [
                instance.get(input_key) for instance in vertex_req.instances
            ]
            break
    image_data = [
        instance.get("image_data")
        for instance in vertex_req.instances
        if instance.get("image_data") is not None
    ] or None
    req = GenerateReqInput(
        **inputs,
        image_data=image_data,
        **(vertex_req.parameters or {}),
    )
    ret = await generate_request(req, raw_request)
    return ORJSONResponse({"predictions": ret})


def _create_error_response(e):
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
    )


def launch_server(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
    launch_callback: Optional[Callable[[], None]] = None,
):
    """
    Launch SRT (SGLang Runtime) Server.

    The SRT server consists of an HTTP server and an SRT engine.

    - HTTP server: A FastAPI server that routes requests to the engine.
    - The engine consists of three components:
        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
    """
    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
    set_global_state(
        _GlobalState(
            tokenizer_manager=tokenizer_manager,
            scheduler_info=scheduler_info,
        )
    )

    # Add api key authorization
    if server_args.api_key:
        add_api_key_middleware(app, server_args.api_key)

    # Add prometheus middleware
    if server_args.enable_metrics:
        add_prometheus_middleware(app)
        enable_func_timer()

    # Send a warmup request - we will create the thread and launch it
    # in the lifespan after all other warmups have fired.
    warmup_thread = threading.Thread(
        target=_wait_and_warmup,
        args=(
            server_args,
            pipe_finish_writer,
            _global_state.tokenizer_manager.image_token_id,
            launch_callback,
        ),
    )
    app.warmup_thread = warmup_thread

    try:
        # Update logging configs
        set_uvicorn_logging_configs()
        app.server_args = server_args
        # Listen for HTTP requests
        uvicorn.run(
            app,
            host=server_args.host,
            port=server_args.port,
            log_level=server_args.log_level_http or server_args.log_level,
            timeout_keep_alive=5,
            loop="uvloop",
        )
    finally:
        warmup_thread.join()


def _wait_and_warmup(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[multiprocessing.connection.Connection],
    image_token_text: str,
    launch_callback: Optional[Callable[[], None]] = None,
):
    headers = {}
    url = server_args.url()
    if server_args.api_key:
        headers["Authorization"] = f"Bearer {server_args.api_key}"

    # Wait until the server is launched
    success = False
    for _ in range(120):
        time.sleep(1)
        try:
            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
            assert res.status_code == 200, f"{res=}, {res.text=}"
            success = True
            break
        except (AssertionError, requests.exceptions.RequestException):
            last_traceback = get_exception_traceback()

    if not success:
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    model_info = res.json()

    # Send a warmup request
    request_name = "/generate" if model_info["is_generation"] else "/encode"
    max_new_tokens = 8 if model_info["is_generation"] else 1
    json_data = {
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    if server_args.skip_tokenizer_init:
        json_data["input_ids"] = [10, 11, 12]
    else:
        json_data["text"] = "The capital city of France is"

    # Debug dumping
    if server_args.debug_tensor_dump_input_file:
        json_data.pop("text", None)
        json_data["input_ids"] = np.load(
            server_args.debug_tensor_dump_input_file
        ).tolist()
        json_data["sampling_params"]["max_new_tokens"] = 0

    try:
        for _ in range(server_args.dp_size):
            res = requests.post(
                url + request_name,
                json=json_data,
                headers=headers,
                timeout=600,
            )
            assert res.status_code == 200, f"{res}"
    except Exception:
        last_traceback = get_exception_traceback()
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(last_traceback)
        logger.error(f"Initialization failed. warmup error: {last_traceback}")
        kill_process_tree(os.getpid())
        return

    # Debug print
    # logger.info(f"{res.json()=}")

    logger.info("The server is fired up and ready to roll!")
    if pipe_finish_writer is not None:
        pipe_finish_writer.send("ready")

    if server_args.delete_ckpt_after_loading:
        delete_directory(server_args.model_path)

    if server_args.debug_tensor_dump_input_file:
        kill_process_tree(os.getpid())

    if launch_callback is not None:
        launch_callback()