# sglang_inc.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
5
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
6
7
8

import argparse
import asyncio
9
import json
10
import logging
11
import sys
12
from typing import Optional
13
14
15
16
17
18
19
20

import sglang
import uvloop
from sglang.srt.server_args import ServerArgs

from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker

21
# Fallback values; they only matter when this script is started by hand from the
# command line (they are the argparse defaults for --model-path and --endpoint).
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"

# Configure root logging at DEBUG level as soon as the module is loaded.
logging.basicConfig(level=logging.DEBUG)

27
28
29
30
31
32
33

class Config:
    """Command line parameters or defaults

    Plain attribute bag: instances are created empty with `Config()` and every
    field is assigned afterwards by `cmd_line_args`.
    """

    # The three dot-separated parts of the `dyn://namespace.component.endpoint` string.
    namespace: str
    component: str
    endpoint: str

    # Path to a model on disk or a HuggingFace model identifier.
    model_path: str
    # Name to serve the model under; None when no name was given on the command line.
    model_name: Optional[str]

    # First GPU ID to allocate from.
    base_gpu_id: int
    # Number of GPUs to use; forwarded to the engine as `tp_size`.
    tensor_parallel_size: int
    # Size of a KV cache block; forwarded to the engine as `page_size`.
    kv_block_size: int
    # Max model context length; None (the command line default) leaves the
    # engine's own default in place.
    context_length: Optional[int]

    # Multi-node settings; only forwarded to the engine when dist_init_addr is non-empty.
    nnodes: int
    node_rank: int
    dist_init_addr: str

    # Path to a JSON file with additional engine keyword arguments; "" means none.
    extra_engine_args: str


class RequestHandler:
    """
    Request handler for the generate endpoint
    """

    def __init__(self, engine):
        """Keep a reference to the engine whose `async_generate` serves requests."""
        self.engine_client = engine

    async def generate(self, request):
        """
        Stream newly generated token ids for one request.

        Reads `request["token_ids"]`, `request["stop_conditions"]["max_tokens"]`
        and `request["sampling_options"]["temperature"]`.

        Yields dicts of the form `{"token_ids": [...]}` holding only the tokens
        that were not part of a previous yield. The final dict carries an empty
        `token_ids` list plus a `finish_reason` string.
        """
        sampling_params = {
            # sglang defaults this to 128
            "max_new_tokens": request["stop_conditions"]["max_tokens"],
        }
        # Add the temperature AFTER building the base dict so it is not
        # overwritten; leave it out entirely when the request did not set one.
        temperature = request["sampling_options"]["temperature"]
        if temperature is not None:
            sampling_params["temperature"] = temperature

        num_output_tokens_so_far = 0
        gen = await self.engine_client.async_generate(
            input_ids=request["token_ids"], sampling_params=sampling_params, stream=True
        )
        async for res in gen:
            # res is a dict

            finish_reason = res["meta_info"]["finish_reason"]
            if finish_reason:
                # Don't forward the stop token
                out = {"token_ids": [], "finish_reason": finish_reason["type"]}
            else:
                # output_ids is treated as cumulative: only the tail beyond what
                # was already forwarded is sent on.
                output_ids = res["output_ids"]
                out = {"token_ids": output_ids[num_output_tokens_so_far:]}
                # Advance the cursor only here, so a stream whose very first
                # chunk is already finished never touches an unset counter.
                num_output_tokens_so_far = len(output_ids)
            yield out


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Parse the command line into a Config, then build and serve the engine."""
    config = cmd_line_args()
    await init(runtime, config)


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate and serve

    Builds the SGLang `ServerArgs` from `config`, creates the `sglang.Engine`,
    registers the model with Dynamo and then serves `RequestHandler.generate`
    on the configured endpoint.

    Args:
        runtime: Dynamo runtime used to resolve the namespace/component/endpoint.
        config: Settings, normally produced by `cmd_line_args`.
    """

    arg_map = {
        "model_path": config.model_path,
        # The request handler passes token ids straight to the engine.
        "skip_tokenizer_init": True,
        "tp_size": config.tensor_parallel_size,
        "base_gpu_id": config.base_gpu_id,
    }

    # Falsy values (0 / None) leave the engine's own defaults in place.
    if config.kv_block_size:
        arg_map["page_size"] = config.kv_block_size

    if config.context_length:
        arg_map["context_length"] = config.context_length

    if config.dist_init_addr != "":
        # NOTE(review): trust_remote_code is only enabled on this multi-node
        # path - confirm that this is intentional.
        arg_map["trust_remote_code"] = True
        arg_map["nnodes"] = config.nnodes
        arg_map["dist_init_addr"] = config.dist_init_addr
        # In practice this is always 0 because Dynamo only manages the leader
        arg_map["node_rank"] = config.node_rank

    if config.extra_engine_args != "":
        json_map = {}
        # extra_engine_args is a filename
        try:
            with open(config.extra_engine_args) as f:
                json_map = json.load(f)
        except FileNotFoundError:
            logging.error(f"File {config.extra_engine_args} not found.")
        except json.JSONDecodeError as e:
            logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
        if not isinstance(json_map, dict):
            # Valid JSON that is not an object (list, number, ...) cannot be
            # merged as keyword arguments; report it like the other load
            # failures instead of crashing on the merge below.
            logging.error(
                f"Expected a JSON object in {config.extra_engine_args}, got {type(json_map).__name__}."
            )
            json_map = {}
        logging.debug(f"Adding extra engine arguments: {json_map}")
        arg_map = {**arg_map, **json_map}  # json_map gets precedence

    # TODO fetch default SamplingParams from generation_config.json

    engine_args = ServerArgs(**arg_map)
    engine_client = sglang.Engine(server_args=engine_args)

    component = runtime.namespace(config.namespace).component(config.component)
    await component.create_service()

    endpoint = component.endpoint(config.endpoint)
    await register_llm(
        ModelType.Backend, endpoint, config.model_path, config.model_name
    )

    # The server shuts down gracefully (i.e., lets open TCP streams finish)
    # after the lease is revoked
    await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
139
140
141
142
143
144
145
146
147
148
149
150
151


def cmd_line_args(argv=None):
    """
    Parse command line arguments into a `Config`.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the process command line.

    Returns:
        A populated `Config`.

    Exits the process with status 1 when `--endpoint` is not of the form
    `dyn://namespace.component.endpoint` (the `dyn://` prefix is optional).
    """
    parser = argparse.ArgumentParser(
        description="SGLang server integrated with Dynamo LLM."
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=DEFAULT_ENDPOINT,
        help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="",
        help="Name to serve the model under. Defaults to deriving it from model path.",
    )
    parser.add_argument(
        "--base-gpu-id",
        type=int,
        default=0,
        help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
    )
    parser.add_argument(
        "--kv-block-size", type=int, default=16, help="Size of a KV cache block."
    )
    parser.add_argument(
        "--context-length",
        type=int,
        default=None,
        help="Max model context length. Defaults to models max, usually model_max_length from tokenizer_config.json. Reducing this reduces VRAM requirements.",
    )
    parser.add_argument(
        "--nnodes", type=int, default=1, help="The number of machines SGLang will use"
    )
    parser.add_argument(
        "--node-rank",
        type=int,
        default=0,
        help="Unique number for each node. 0 for the leader.",
    )
    parser.add_argument(
        "--dist-init-addr",
        type=str,
        default="",
        help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
    )
    parser.add_argument(
        "--extra-engine-args",
        type=str,
        default="",
        help="Path to a JSON file containing additional keyword arguments to pass to the SGLang Engine.",
    )
    args = parser.parse_args(argv)

    config = Config()
    config.model_path = args.model_path
    if args.model_name:
        config.model_name = args.model_name
    else:
        # This becomes an `Option` on the Rust side
        config.model_name = None

    # Strip the scheme only when it is a real prefix, never from the middle.
    scheme = "dyn://"
    endpoint_str = args.endpoint
    if endpoint_str.startswith(scheme):
        endpoint_str = endpoint_str[len(scheme):]
    endpoint_parts = endpoint_str.split(".")
    # Exactly three parts are required and none may be empty (e.g. "a..b").
    if len(endpoint_parts) != 3 or not all(endpoint_parts):
        logging.error(
            f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )
        sys.exit(1)

    parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts

    config.namespace = parsed_namespace
    config.component = parsed_component_name
    config.endpoint = parsed_endpoint_name
    config.base_gpu_id = args.base_gpu_id
    config.tensor_parallel_size = args.tensor_parallel_size
    config.kv_block_size = args.kv_block_size
    config.context_length = args.context_length
    config.nnodes = args.nnodes
    config.node_rank = args.node_rank
    config.dist_init_addr = args.dist_init_addr
    config.extra_engine_args = args.extra_engine_args

    return config


if __name__ == "__main__":
    # Standalone entry point (`python3 sglang_inc.py`).
    # uvloop is installed before asyncio.run() is called, so that the event loop
    # created for the worker is presumably uvloop-backed.
    uvloop.install()
    # worker is called without arguments here; the `runtime` argument is
    # presumably supplied by the @dynamo_worker decorator - confirm.
    asyncio.run(worker())