# "docs/backends/trtllm/README.md" did not exist on "e542f00356e0d0a4d002054ea3c49e4ee759cb2a"
# worker.py 10.6 KB
# Newer Older
1
2
3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
import asyncio
5
import logging
6
7
import random
import socket
8
9
import sys
from typing import Any, Dict, Optional, Union
10
11

import sglang as sgl
12
13
import uvloop
from sglang.srt.server_args import ServerArgs
14
from sglang.srt.utils import get_ip
15
16
from utils.protocol import DisaggPreprocessedRequest
from utils.sgl_utils import parse_sglang_args_inc
17

18
19
20
21
22
23
24
from dynamo.llm import (
    ModelType,
    WorkerMetricsPublisher,
    ZmqKvEventPublisher,
    ZmqKvEventPublisherConfig,
    register_llm,
)
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging

# Apply Dynamo's logging configuration once, at import time, before any class
# or function below emits log records.
configure_dynamo_logging()


class RequestHandler:
    """Serve Dynamo ``generate`` requests from an SGLang engine.

    When ``server_args.disaggregation_mode == "null"`` requests are generated
    entirely on the local engine. Otherwise this worker runs the (discarded)
    prefill stream locally and forwards the request, together with SGLang
    bootstrap host/port/room values, to a decode component via
    ``decode_client``, streaming back the decode worker's output.
    """

    def __init__(
        self,
        engine: "sgl.Engine",
        server_args: "ServerArgs",
        component,
        decode_client: Optional[Any] = None,
    ):
        """Create the handler.

        Args:
            engine: SGLang engine used for generation.
            server_args: SGLang server arguments; ``disaggregation_mode``
                selects aggregated vs. disaggregated serving.
            component: Dynamo component the metrics endpoint is created on.
            decode_client: Client whose ``generate`` is called with the
                disaggregated request; required unless
                ``disaggregation_mode`` is ``"null"``.

        Raises:
            ValueError: If disaggregation is enabled but ``decode_client`` is
                ``None``.
        """
        self.engine = engine
        self.server_args = server_args
        self.component = component
        self.metrics_publisher = WorkerMetricsPublisher()
        # Strong reference to the background task started by setup_metrics().
        self._metrics_endpoint_task: Optional["asyncio.Task"] = None

        if server_args.disaggregation_mode != "null":
            # Validate the arguments before doing any host/IP lookups.
            if decode_client is None:
                raise ValueError(
                    "decode_client must be provided when disaggregation_mode is not 'null'"
                )
            self.decode_client = decode_client
            self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
            logging.info(
                f"Disaggregation enabled - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
            )

        logging.info("Request handler initialized")

    def setup_metrics(self):
        """Set up metrics publisher - call this after handler creation.

        Publishes an initial set of placeholder metrics synchronously, then
        schedules endpoint creation on the running event loop. Must be called
        from within a running event loop (it uses ``asyncio.create_task``).
        """
        self.metrics_publisher.publish(
            request_active_slots=0,
            request_total_slots=1024,
            kv_active_blocks=0,
            kv_total_blocks=1024,
            num_requests_waiting=0,
            gpu_cache_usage_perc=0.0,
            gpu_prefix_cache_hit_rate=0.0,
        )
        # The event loop keeps only weak references to tasks; store the task
        # so it cannot be garbage-collected before it completes.
        self._metrics_endpoint_task = asyncio.create_task(
            self.create_metrics_publisher_endpoint()
        )
        self._metrics_endpoint_task.add_done_callback(
            self._on_metrics_endpoint_done
        )

    @staticmethod
    def _on_metrics_endpoint_done(task: "asyncio.Task") -> None:
        """Log the outcome of the endpoint-creation task (retrieving any error)."""
        if task.cancelled():
            logging.debug("metrics publisher endpoint creation cancelled")
            return
        exc = task.exception()
        if exc is not None:
            logging.error(
                "Failed to create metrics publisher endpoint", exc_info=exc
            )
        else:
            logging.debug("metrics publisher endpoint created")

    async def create_metrics_publisher_endpoint(self):
        """Create the metrics publisher endpoint on ``self.component``."""
        logging.debug("Creating metrics publisher endpoint")
        await self.metrics_publisher.create_endpoint(self.component)

    def _update_metrics(self):
        """Publish randomized placeholder metrics (NOT real engine state).

        NOTE(review): not called anywhere in this file's visible code.
        """
        # TODO: remove this once the following upstream changes are merged:
        #   • sgl-project/sglang#6721 – "Expose runtime KV-cache & request metrics"
        logging.warning(
            "Publishing placeholder metrics in SGLangWorker; these are NOT real engine metrics yet and will be replaced once upstream support lands."
        )
        self.metrics_publisher.publish(
            request_active_slots=1,
            request_total_slots=100,
            kv_active_blocks=random.randint(0, 500),
            kv_total_blocks=1000,
            num_requests_waiting=0,
            gpu_cache_usage_perc=random.uniform(0.1, 0.8),
            gpu_prefix_cache_hit_rate=random.uniform(0.0, 0.5),
        )

    def _get_bootstrap_info(self):
        """Return ``(host, port)`` for disaggregation bootstrap.

        The port comes from the tokenizer manager's
        ``disaggregation_bootstrap_port``. The host is the resolved host part of
        ``dist_init_addr`` when set, otherwise this machine's IP (``get_ip``).
        """
        inner_tm = self.engine.tokenizer_manager
        bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port

        if inner_tm.server_args.dist_init_addr:
            bootstrap_host = socket.gethostbyname(
                inner_tm.server_args.dist_init_addr.split(":")[0]
            )
        else:
            bootstrap_host = get_ip()

        return bootstrap_host, bootstrap_port

    def _build_sampling_params(self, request: dict) -> dict:
        """Translate the request's sampling options / stop conditions to SGLang.

        ``temperature`` is forwarded whenever it is present and not ``None``,
        including ``0.0`` (commonly used to request greedy decoding).
        ``top_p``, ``top_k`` and ``ignore_eos`` are forwarded only when truthy,
        so zero/False values leave the engine defaults in place.
        ``max_new_tokens`` is always set from ``stop_conditions["max_tokens"]``.
        """
        sampling_options = request["sampling_options"]
        stop_conditions = request["stop_conditions"]
        sampling_params = {}

        # 0.0 is a meaningful temperature, so compare against None rather than
        # relying on truthiness (which would silently drop it).
        temperature = sampling_options.get("temperature")
        if temperature is not None:
            sampling_params["temperature"] = temperature
        # Zero top_p / top_k are treated as "unset", as before.
        if sampling_options.get("top_p"):
            sampling_params["top_p"] = sampling_options["top_p"]
        if sampling_options.get("top_k"):
            sampling_params["top_k"] = sampling_options["top_k"]
        sampling_params["max_new_tokens"] = stop_conditions.get("max_tokens")
        if stop_conditions.get("ignore_eos"):
            sampling_params["ignore_eos"] = stop_conditions["ignore_eos"]

        return sampling_params

    def _get_request_batch_size(self, request: dict):
        """Get batch size from request, returns None for single requests"""
        batch_token_ids = request.get("batch_token_ids")
        if batch_token_ids is None:
            return None
        return len(batch_token_ids)

    def _is_batch_request(self, request: dict):
        """Check if request is in batch mode (``batch_token_ids`` is set)."""
        return request.get("batch_token_ids") is not None

    def _generate_bootstrap_room(self):
        """Return a random bootstrap room id in ``[0, 2**63 - 1]``."""
        return random.randint(0, 2**63 - 1)

    async def generate(self, request: dict):
        """Stream generated tokens for one request (single or batch).

        Yields:
            Dicts with ``"token_ids"`` (only tokens not yet emitted), plus
            ``"finish_reason"`` on a finishing chunk and ``"index"`` in batch
            mode.
        """
        is_batch = self._is_batch_request(request)
        batch_size = self._get_request_batch_size(request)
        input_ids = request["batch_token_ids"] if is_batch else request["token_ids"]

        # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
        sampling_params = self._build_sampling_params(request)

        if self.server_args.disaggregation_mode == "null":
            # Aggregated mode: generate entirely on the local engine.
            g = await self.engine.async_generate(
                input_ids=input_ids,
                sampling_params=sampling_params,
                stream=True,
            )
            async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
                yield out
            return

        # Disaggregated mode: every prompt in a batch gets its own room id but
        # shares this worker's bootstrap host/port.
        if is_batch:
            bootstrap_room = [
                self._generate_bootstrap_room() for _ in range(batch_size)
            ]
            bootstrap_host = [self.bootstrap_host] * batch_size
            bootstrap_port = [self.bootstrap_port] * batch_size
        else:
            bootstrap_host = self.bootstrap_host
            bootstrap_port = self.bootstrap_port
            bootstrap_room = self._generate_bootstrap_room()

        # decode worker request
        disagg_request = DisaggPreprocessedRequest(
            request=request,
            sampling_params=sampling_params,
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
        )

        # prefill response is not used; it is drained in the background
        prefill = await self.engine.async_generate(
            input_ids=input_ids,
            sampling_params=sampling_params,
            stream=True,
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
        )
        prefill_task = asyncio.create_task(self._prefill_generator(prefill))

        try:
            decode = await self.decode_client.generate(disagg_request.model_dump_json())
            async for out in self._process_stream(
                decode, unpack=True, is_batch=is_batch
            ):
                yield out
        except BaseException:
            # Decode failed or the consumer stopped iterating (GeneratorExit /
            # cancellation): don't leave the prefill task running unobserved.
            prefill_task.cancel()
            raise

        # Surface any prefill error once decode has finished streaming.
        await prefill_task

    async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
        """Convert an SGLang-style output stream into incremental token chunks.

        Args:
            stream_source: Async iterable of outputs, or of wrappers exposing
                ``.data()`` when ``unpack`` is true.
            unpack: Whether to call ``.data()`` on each item.
            is_batch: Whether outputs carry an ``"index"`` identifying the
                batch entry; it is echoed back in each yielded dict.

        Yields:
            ``{"token_ids": [...]}`` with only not-yet-emitted tokens, plus
            ``"finish_reason"`` (the ``type`` field) on finishing chunks and
            ``"index"`` in batch mode. Tokens that arrive on the finishing chunk
            are emitted rather than dropped.
        """
        # Tokens already emitted per stream index. Assumes output_ids is the
        # cumulative sequence so far (the slicing below relies on it).
        emitted: Dict[int, int] = {}

        async for res in stream_source:
            data = res.data() if unpack else res
            finish_reason = data["meta_info"]["finish_reason"]
            index = data.get("index", 0) if is_batch else 0

            output_ids = data.get("output_ids") or []
            already = emitted.get(index, 0)
            out: Dict[str, Any] = {"token_ids": output_ids[already:]}
            emitted[index] = max(already, len(output_ids))

            if finish_reason:
                out["finish_reason"] = finish_reason["type"]
            if is_batch:
                out["index"] = index
            yield out

    async def _prefill_generator(self, prefill):
        """Drain the local prefill stream; its outputs are intentionally unused."""
        async for _ in prefill:
            pass


@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Dynamo entry point: parse SGLang CLI flags and hand off to ``init``."""
    cli_args = sys.argv[1:]
    parsed_server_args = parse_sglang_args_inc(cli_args)
    await init(runtime, parsed_server_args)


async def init(runtime: DistributedRuntime, server_args: ServerArgs):
    """Initialize worker (either prefill or aggregated).

    Builds the SGLang engine, registers the ``dynamo/worker/generate`` endpoint,
    wires up a decode client when disaggregation is enabled, starts metrics and
    the ZMQ KV-event publisher, then serves requests until the endpoint stops.
    """
    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace("dynamo").component("worker")
    await component.create_service()

    endpoint = component.endpoint("generate")
    await register_llm(
        ModelType.Backend,
        endpoint,
        server_args.model_path,
        server_args.served_model_name,
        kv_cache_block_size=server_args.page_size,
    )

    # Only disaggregated (prefill) workers talk to a decode component.
    decode_client = None
    if server_args.disaggregation_mode != "null":
        decode_endpoint = (
            runtime.namespace("dynamo").component("decode").endpoint("generate")
        )
        decode_client = await decode_endpoint.client()
    handler = RequestHandler(engine, server_args, component, decode_client)

    # Set up metrics in background
    handler.setup_metrics()

    # Set up ZMQ kv event publisher; the local name keeps it referenced while
    # this coroutine is suspended in serve_endpoint below.
    kv_event_publisher = ZmqKvEventPublisher(
        component=component,
        config=ZmqKvEventPublisherConfig(
            worker_id=endpoint.lease_id(),
            kv_block_size=server_args.page_size,
        ),
    )

    await endpoint.serve_endpoint(handler.generate)


if __name__ == "__main__":
    # Install uvloop's event-loop policy before asyncio.run() creates the loop.
    uvloop.install()
    entrypoint = worker()
    asyncio.run(entrypoint)