worker.py 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SGLang disaggregated serving flow is

Processor -> PrefillWorker -> DecodeWorker

This is different from how we've implemented the vLLM disaggregated flow.

For now, the SGLangWorker is responsible for both aggregated serving and prefill,
and we have a separate DecodeWorker.
"""

27
import asyncio
28
import logging
29
30
import random
import socket
31
from typing import Dict, Union
32
33

import sglang as sgl
34
35
36
from components.decode_worker import SGLangDecodeWorker
from sglang.srt.utils import get_ip
from utils.protocol import DisaggPreprocessedRequest, PreprocessedRequest
37
from utils.sgl_utils import parse_sglang_args
38

39
40
41
42
43
44
45
from dynamo.llm import (
    ModelType,
    WorkerMetricsPublisher,
    ZmqKvEventPublisher,
    ZmqKvEventPublisherConfig,
    register_llm,
)
46
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
47
48
49
50
51
52
53
54
55
56
57
58

logger = logging.getLogger(__name__)


@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"gpu": 1},
    workers=1,
)
class SGLangWorker:
59
60
    decode_worker = depends(SGLangDecodeWorker)

61
62
63
64
65
    def __init__(self):
        """Parse the SGLang arguments, start the engine and create the metrics publisher."""
        # Engine arguments are looked up under this class's name.
        self.engine_args = parse_sglang_args(type(self).__name__, "")
        self.engine = sgl.Engine(server_args=self.engine_args)

        # Publisher through which this worker reports its metrics.
        self.metrics_publisher = WorkerMetricsPublisher()

    def _update_metrics(self):
        """Publish a placeholder metrics snapshot through the metrics publisher.

        The published values are constants or random numbers, not readings
        taken from the engine; a warning saying so is logged on every call.
        """
        # TODO: remove this once the following upstream changes are merged:
        #   • ai-dynamo/dynamo#1465 – "feat: receive kvmetrics from sglang scheduler"
        #   • sgl-project/sglang#6721 – "Expose runtime KV-cache & request metrics"
        logger.warning(
            "Publishing placeholder metrics in SGLangWorker; these are NOT real engine metrics yet and will be replaced once upstream support lands."
        )
        placeholder_metrics = {
            "request_active_slots": 1,
            "request_total_slots": 100,
            "kv_active_blocks": random.randint(0, 500),
            "kv_total_blocks": 1000,
            "num_requests_waiting": 0,
            "gpu_cache_usage_perc": random.uniform(0.1, 0.8),
            "gpu_prefix_cache_hit_rate": random.uniform(0.0, 0.5),
        }
        self.metrics_publisher.publish(**placeholder_metrics)

    async def create_metrics_publisher_endpoint(self):
        """Create the metrics publisher's endpoint on the current Dynamo component."""
        await self.metrics_publisher.create_endpoint(dynamo_context["component"])
90
91
92
93
94
95

    @async_on_start
    async def async_init(self):
        """Register this worker with the runtime and set up its publishers.

        Steps, in order:
          1. Register the model on this component's ``generate`` endpoint.
          2. Publish an initial metrics snapshot (zero active slots/blocks) and
             start creating the metrics publisher endpoint in the background.
          3. When disaggregation is enabled, record the bootstrap host/port and
             open a client to the decode worker's ``generate`` endpoint.
          4. Create the ZMQ KV event publisher.
        """
        runtime = dynamo_context["runtime"]
        comp_ns, comp_name = SGLangWorker.dynamo_address()  # type: ignore
        endpoint = runtime.namespace(comp_ns).component(comp_name).endpoint("generate")
        component = runtime.namespace(comp_ns).component(comp_name)

        logger.info(
            f"Registering LLM for discovery with kv block size {self.engine_args.page_size}, endpoint={endpoint}, model_path={self.engine_args.model_path}, served_model_name={self.engine_args.served_model_name}"
        )
        await register_llm(
            ModelType.Backend,
            endpoint,
            self.engine_args.model_path,
            self.engine_args.served_model_name,
            kv_cache_block_size=self.engine_args.page_size,
        )

        self.metrics_publisher.publish(
            request_active_slots=0,
            request_total_slots=1024,
            kv_active_blocks=0,
            kv_total_blocks=1024,
            num_requests_waiting=0,
            gpu_cache_usage_perc=0.0,
            gpu_prefix_cache_hit_rate=0.0,
        )

        # Create metrics publisher endpoint for KV router discovery.
        # The event loop only keeps a weak reference to tasks, so hold the
        # task on the instance to stop it being garbage-collected mid-flight.
        self._metrics_endpoint_task = asyncio.create_task(
            self.create_metrics_publisher_endpoint()
        )

        # "null" is the sentinel that `generate` treats as "disaggregation
        # disabled". It is a non-empty string, so a plain truthiness test
        # would run this setup in aggregated mode as well.
        if self.engine_args.disaggregation_mode != "null":
            self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
            decode_ns, decode_name = SGLangDecodeWorker.dynamo_address()  # type: ignore
            self.decode_client = (
                await runtime.namespace(decode_ns)
                .component(decode_name)
                .endpoint("generate")
                .client()
            )

        # Configure ZMQ KV Event Publisher to relay KV events from SGLang to NATS
        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=endpoint.lease_id(),
            kv_block_size=self.engine_args.page_size,  # Keep in sync with register_llm above
        )

        # Keep a reference on the instance to avoid the publisher being garbage-collected.
        self._kv_event_publisher = ZmqKvEventPublisher(
            component=component,
            config=zmq_config,
        )

144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
    def _get_bootstrap_info(self):
        """
        Bootstrap info is stored in the worker's tokenizer manager. We use it to
        add servers to the bootstrap_room
        """
        inner_tm = self.engine.tokenizer_manager
        bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port

        # multinode check
        if inner_tm.server_args.dist_init_addr:
            bootstrap_host = socket.gethostbyname(
                inner_tm.server_args.dist_init_addr.split(":")[0]
            )
        else:
            bootstrap_host = get_ip()

        return bootstrap_host, bootstrap_port
161
162
163
164
165
166
167
168
169
170
171
172
173
174

    def _build_sampling_params(self, request: PreprocessedRequest) -> dict:
        sampling_params = {}
        if request.sampling_options.temperature:
            sampling_params["temperature"] = request.sampling_options.temperature
        if request.sampling_options.top_p:
            sampling_params["top_p"] = request.sampling_options.top_p
        if request.sampling_options.top_k:
            sampling_params["top_k"] = request.sampling_options.top_k
        sampling_params["max_new_tokens"] = request.stop_conditions.max_tokens
        if request.stop_conditions.ignore_eos:
            sampling_params["ignore_eos"] = request.stop_conditions.ignore_eos
        return sampling_params

175
176
177
178
179
180
181
182
183
184
    def _get_request_batch_size(self, request: PreprocessedRequest):
        """Get batch size from request, returns None for single requests"""
        if request.batch_token_ids is not None:
            return len(request.batch_token_ids)
        return None

    def _is_batch_request(self, request: PreprocessedRequest):
        """Check if request is in batch mode"""
        return request.batch_token_ids is not None

185
    @endpoint()
    async def generate(self, request: PreprocessedRequest):
        """Stream generation output for a single or batched request.

        When ``disaggregation_mode`` is not ``"null"``, the local engine is
        started on the request with freshly generated bootstrap room ids and
        drained in a background task, while the same request is sent to the
        decode worker; the decode worker's stream is what gets yielded.
        Otherwise the local engine's own stream is yielded.

        Args:
            request: Carries ``token_ids`` (single) or ``batch_token_ids``
                (batch) plus sampling options and stop conditions.

        Yields:
            The dicts produced by ``_process_stream``.
        """
        # Check if we're in batch mode at the start
        is_batch = self._is_batch_request(request)
        batch_size = self._get_request_batch_size(request)

        # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
        sampling_params = self._build_sampling_params(request)
        input_ids = request.batch_token_ids if is_batch else request.token_ids

        if self.engine_args.disaggregation_mode != "null":
            if is_batch:
                # One room id per batch entry; host and port are repeated.
                bootstrap_room = [
                    self._generate_bootstrap_room() for _ in range(batch_size)
                ]
                bootstrap_host = [self.bootstrap_host] * batch_size
                bootstrap_port = [self.bootstrap_port] * batch_size
            else:
                bootstrap_host = self.bootstrap_host
                bootstrap_port = self.bootstrap_port
                bootstrap_room = self._generate_bootstrap_room()

            # decode worker request
            disagg_request = DisaggPreprocessedRequest(
                request=request,
                sampling_params=sampling_params,
                bootstrap_host=bootstrap_host,
                bootstrap_port=bootstrap_port,
                bootstrap_room=bootstrap_room,
            )

            # prefill response is not used
            prefill = await self.engine.async_generate(
                input_ids=input_ids,
                sampling_params=sampling_params,
                stream=True,
                bootstrap_host=bootstrap_host,
                bootstrap_port=bootstrap_port,
                bootstrap_room=bootstrap_room,
            )
            prefill_task = asyncio.create_task(self._prefill_generator(prefill))

            try:
                decode = await self.decode_client.generate(
                    disagg_request.model_dump_json()
                )
                async for out in self._process_stream(
                    decode, unpack=True, is_batch=is_batch
                ):
                    yield out

                await prefill_task
            finally:
                # If the decode call/stream failed or the consumer stopped
                # early, do not leave the prefill drain task running unowned.
                if not prefill_task.done():
                    prefill_task.cancel()
        else:
            g = await self.engine.async_generate(
                input_ids=input_ids,
                sampling_params=sampling_params,
                stream=True,
            )

            async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
                yield out

248
249
250
251
252
253
254
255
    async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
        # Initialize based on batch mode
        num_output_tokens_so_far: Union[Dict[int, int], int]
        if is_batch:
            num_output_tokens_so_far = {}
        else:
            num_output_tokens_so_far = 0

256
257
258
        async for res in stream_source:
            data = res.data() if unpack else res
            finish_reason = data["meta_info"]["finish_reason"]
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280

            if is_batch:
                # Handle batch response
                assert isinstance(num_output_tokens_so_far, dict)
                index = data.get("index", 0)
                if index not in num_output_tokens_so_far:
                    num_output_tokens_so_far[index] = 0

                if finish_reason:
                    out = {
                        "token_ids": [],
                        "finish_reason": finish_reason["type"],
                        "index": index,
                    }
                else:
                    next_total_toks = len(data["output_ids"])
                    new_tokens = data["output_ids"][num_output_tokens_so_far[index] :]
                    out = {
                        "token_ids": new_tokens,
                        "index": index,
                    }
                    num_output_tokens_so_far[index] = next_total_toks
281
            else:
282
283
284
285
286
287
288
289
290
                # Handle single response
                assert isinstance(num_output_tokens_so_far, int)
                if finish_reason:
                    out = {"token_ids": [], "finish_reason": finish_reason["type"]}
                else:
                    next_total_toks = len(data["output_ids"])
                    out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
                    num_output_tokens_so_far = next_total_toks

291
            yield out
292
293
294
295
296
297
298

    def _generate_bootstrap_room(self):
        return random.randint(0, 2**63 - 1)

    async def _prefill_generator(self, prefill):
        async for _ in prefill:
            pass