"examples/multimodal/launch/agg.sh" did not exist on "1327e3bbdb8ba0aaf75bdba9963893c513c85254"
sglang_inc.py 7.15 KB
Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
5
# `dynamo-run out=sglang` runs this script
# Can also be used standalone: `python3 sglang_inc.py` - lots of optional cmd line params
6
7
8

import argparse
import asyncio
9
import logging
10
import sys
11
from typing import Optional
12
13
14
15
16
17
18
19

import sglang
import uvloop
from sglang.srt.server_args import ServerArgs

from dynamo.llm import ModelType, register_llm
from dynamo.runtime import DistributedRuntime, dynamo_worker

20
# Only used if you run it manually from the command line
# (they are the argparse defaults for --endpoint and --model-path).
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"

# Configure the root logger at DEBUG level as soon as the module is loaded.
logging.basicConfig(level=logging.DEBUG)

26
27
28
29
30
31
32

class Config:
    """Command line parameters or defaults

    A plain attribute bag: instances are created with no arguments and the
    fields below are assigned one by one after argument parsing.
    """

    # The three dot-separated parts of the endpoint string
    # ('dyn://namespace.component.endpoint').
    namespace: str
    component: str
    endpoint: str

    # Path to disk model or HuggingFace model identifier to load.
    model_path: str
    # Name to serve the model under; None when no name was supplied.
    model_name: Optional[str]

    # The base GPU ID to start allocating GPUs from.
    base_gpu_id: int
    # Number of GPUs to use.
    tensor_parallel_size: int
    # Size of a KV cache block.
    kv_block_size: int

    # Multi-node settings: machine count, this node's rank (0 for the leader)
    # and the host address of the rank-0 node ("" when not set).
    nnodes: int
    node_rank: int
    dist_init_addr: str

    # Path to a JSON file with additional engine keyword arguments ("" if unset).
    extra_engine_args: str


class RequestHandler:
    """
    Request handler for the generate endpoint

    Holds an engine object and streams newly generated token ids back to the
    caller for each pre-tokenized request.
    """

    def __init__(self, engine):
        # Engine object exposing an awaitable `async_generate(...)`.
        self.engine_client = engine

    async def generate(self, request):
        """
        Stream the engine's output for one request.

        Args:
            request: dict. Reads `request["token_ids"]`,
                `request["sampling_options"]["temperature"]` and
                `request["stop_conditions"]["max_tokens"]`.

        Yields:
            `{"token_ids": [...]}` holding only the tokens not yet sent, and
            `{"token_ids": [], "finish_reason": <type>}` for a chunk that
            carries a finish reason.
        """
        sampling_params = {
            # sglang defaults this to 128
            "max_new_tokens": request["stop_conditions"]["max_tokens"],
        }
        # Add the temperature *after* building the dict so it is not discarded.
        temperature = request["sampling_options"]["temperature"]
        if temperature is not None:
            sampling_params["temperature"] = temperature

        num_output_tokens_so_far = 0
        gen = await self.engine_client.async_generate(
            input_ids=request["token_ids"], sampling_params=sampling_params, stream=True
        )
        async for res in gen:
            # res is a dict

            finish_reason = res["meta_info"]["finish_reason"]
            if finish_reason:
                # Don't forward the stop token
                # NOTE(review): this also drops any tokens carried by the final
                # chunk for non-stop finishes - confirm that is intended.
                yield {"token_ids": [], "finish_reason": finish_reason["type"]}
                # Nothing was forwarded, so the sent-token count is unchanged.
                continue

            # Each chunk carries the cumulative output; send only the new tail.
            output_ids = res["output_ids"]
            yield {"token_ids": output_ids[num_output_tokens_so_far:]}
            num_output_tokens_so_far = len(output_ids)


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Parse the command line into a Config, then hand off to `init`."""
    config = cmd_line_args()
    await init(runtime, config)


async def init(runtime: DistributedRuntime, config: Config):
    """
    Instantiate and serve

    Creates the component service and endpoint named by `config`, registers
    the model via `register_llm`, builds the SGLang engine from `config`
    (plus any overrides read from the `extra_engine_args` JSON file) and
    serves `RequestHandler.generate` on the endpoint.

    Args:
        runtime: runtime used to resolve the namespace and component.
        config: parsed command line parameters.
    """
    # Function-scope import: `json` is used below but is not among this
    # module's top-level imports.
    import json

    component = runtime.namespace(config.namespace).component(config.component)
    await component.create_service()

    endpoint = component.endpoint(config.endpoint)
    await register_llm(
        ModelType.Backend, endpoint, config.model_path, config.model_name
    )

    arg_map = {
        "model_path": config.model_path,
        "skip_tokenizer_init": True,
        "tp_size": config.tensor_parallel_size,
        "base_gpu_id": config.base_gpu_id,
        "page_size": config.kv_block_size,
    }
    if config.dist_init_addr != "":
        arg_map["trust_remote_code"] = True
        arg_map["nnodes"] = config.nnodes
        arg_map["dist_init_addr"] = config.dist_init_addr
        # In practice this is always 0 because Dynamo only manages the leader
        arg_map["node_rank"] = config.node_rank

    if config.extra_engine_args != "":
        json_map = {}
        # extra_engine_args is a filename
        try:
            with open(config.extra_engine_args) as f:
                json_map = json.load(f)
        except FileNotFoundError:
            logging.error(f"File {config.extra_engine_args} not found.")
        except json.JSONDecodeError as e:
            logging.error(f"Invalid JSON in {config.extra_engine_args}: {e}")
        # Only a JSON object can be merged as keyword arguments; anything else
        # is reported and ignored, like the other load failures above.
        if not isinstance(json_map, dict):
            logging.error(
                f"Expected a JSON object in {config.extra_engine_args}, got {type(json_map).__name__}."
            )
            json_map = {}
        logging.debug(f"Adding extra engine arguments: {json_map}")
        arg_map = {**arg_map, **json_map}  # json_map gets precedence

    # TODO fetch default SamplingParams from generation_config.json

    engine_args = ServerArgs(**arg_map)
    engine_client = sglang.Engine(server_args=engine_args)

    # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
    # after the lease is revoked
    await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
130
131
132
133
134
135
136
137
138
139
140
141
142


def cmd_line_args():
    """
    Parse the process command line into a `Config`.

    Returns:
        Config populated from the parsed arguments. `model_name` is None when
        `--model-name` is empty or omitted.

    Exits the process with status 1 (after logging an error) when the endpoint
    is not of the form `[dyn://]namespace.component.endpoint` with three
    non-empty parts.
    """
    parser = argparse.ArgumentParser(
        description="SGLang server integrated with Dynamo LLM."
    )
    parser.add_argument(
        "--endpoint",
        type=str,
        default=DEFAULT_ENDPOINT,
        help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Path to disk model or HuggingFace model identifier to load. Default: {DEFAULT_MODEL}",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="",
        help="Name to serve the model under. Defaults to deriving it from model path.",
    )
    parser.add_argument(
        "--base-gpu-id",
        type=int,
        default=0,
        help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
    )
    parser.add_argument(
        "--tensor-parallel-size", type=int, default=1, help="Number of GPUs to use."
    )
    parser.add_argument(
        "--kv-block-size", type=int, default=16, help="Size of a KV cache block."
    )
    parser.add_argument(
        "--nnodes", type=int, default=1, help="The number of machines SGLang will use"
    )
    parser.add_argument(
        "--node-rank",
        type=int,
        default=0,
        help="Unique number for each node. 0 for the leader.",
    )
    parser.add_argument(
        "--dist-init-addr",
        type=str,
        default="",
        help="Host address (e.g., `192.168.0.2:25000`) of the node with rank 0",
    )
    parser.add_argument(
        "--extra-engine-args",
        type=str,
        default="",
        help="Path to a JSON file containing additional keyword arguments to pass to the SGLang Engine.",
    )
    args = parser.parse_args()

    # Strip the scheme only when it is really a prefix; a "dyn://" appearing
    # later in the string is not a valid endpoint and must not be removed.
    scheme = "dyn://"
    endpoint_str = args.endpoint
    if endpoint_str.startswith(scheme):
        endpoint_str = endpoint_str[len(scheme):]

    endpoint_parts = endpoint_str.split(".")
    # Require exactly three parts, none of them empty (rejects e.g. "a..b").
    if len(endpoint_parts) != 3 or not all(endpoint_parts):
        logging.error(
            f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )
        sys.exit(1)

    parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts

    config = Config()
    config.model_path = args.model_path
    # An empty name becomes None, which is an `Option` on the Rust side
    config.model_name = args.model_name or None

    config.namespace = parsed_namespace
    config.component = parsed_component_name
    config.endpoint = parsed_endpoint_name
    config.base_gpu_id = args.base_gpu_id
    config.tensor_parallel_size = args.tensor_parallel_size
    config.kv_block_size = args.kv_block_size
    config.nnodes = args.nnodes
    config.node_rank = args.node_rank
    config.dist_init_addr = args.dist_init_addr
    config.extra_engine_args = args.extra_engine_args

    return config


# Script entry point: install uvloop, then run the `worker` coroutine to
# completion on an asyncio event loop.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())