# handler_base.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

4
import asyncio
5
6
import base64
import json
7
import logging
8
9
import random
import socket
10
from abc import ABC, abstractmethod
11
12
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
13
14

import sglang as sgl
15
from sglang.srt.tracing import trace as sglang_trace
16
from sglang.srt.utils import get_local_ip_auto
17

18
from dynamo._core import Component, Context
19
from dynamo.common.utils.input_params import InputParamManager
20
from dynamo.sglang.args import Config
21
from dynamo.sglang.publisher import DynamoSglangPublisher
22
23
24


class BaseWorkerHandler(ABC):
    """Abstract base class for SGLang worker handlers.

    Subclasses implement :meth:`generate`. This base class stores the engine,
    configuration and optional publishers, and provides shared helpers for
    building engine inputs, disaggregation bootstrap info, trace-context
    propagation to SGLang, and request-cancellation monitoring.
    """

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        publisher: Optional[DynamoSglangPublisher] = None,
    ) -> None:
        """Initialize base worker handler.

        Args:
            component: The Dynamo runtime component.
            engine: The SGLang engine instance.
            config: SGLang and Dynamo configuration.
            publisher: Optional metrics publisher for the worker. When given,
                its ``metrics_publisher`` and ``kv_publisher`` are exposed on
                this handler; otherwise both attributes are ``None``.
        """
        self.component = component
        self.engine = engine
        self.config = config
        if publisher is not None:
            self.metrics_publisher = publisher.metrics_publisher
            self.kv_publisher = publisher.kv_publisher
        else:
            self.metrics_publisher = None
            self.kv_publisher = None
        self.serving_mode = config.serving_mode
        self.skip_tokenizer_init = config.server_args.skip_tokenizer_init
        self.enable_trace = config.server_args.enable_trace

        # When tokenizer init is skipped no tokenizer is handed over, and
        # _get_input_param likewise passes use_tokenizer=False.
        self.input_param_manager = InputParamManager(
            self.engine.tokenizer_manager.tokenizer
            if not self.skip_tokenizer_init
            else None
        )

    @abstractmethod
    async def generate(self, request: Dict[str, Any], context: Context):
        """Generate response from request.

        Args:
            request: Request dict with input and parameters.
            context: Context object for cancellation handling.

        Yields:
            Response data (format varies by handler implementation).
        """
        pass

    def cleanup(self) -> None:
        """Cleanup resources. Override in subclasses as needed."""
        pass

    def _get_input_param(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Build the engine input kwargs for ``request``.

        Returns:
            ``{"prompt": <str>}`` when the param manager yields text, otherwise
            ``{"input_ids": <value>}`` (presumably a token-id list).
        """
        request_input = self.input_param_manager.get_input_param(
            request, use_tokenizer=not self.skip_tokenizer_init
        )

        return {
            "prompt" if isinstance(request_input, str) else "input_ids": request_input
        }

    @staticmethod
    def _generate_bootstrap_room() -> int:
        """Generate a random bootstrap room ID for disaggregated serving.

        Uniqueness is probabilistic (random draw), not guaranteed.

        Returns:
            Random integer in ``[0, 2**63 - 1]``.
        """
        return random.randint(0, 2**63 - 1)

    @staticmethod
    def _extract_host(addr: str) -> str:
        """Return the host part of ``addr``.

        Accepts ``host``, ``host:port``, ``[IPv6]`` and ``[IPv6]:port``.
        An unbracketed IPv6 literal (several ``:``) is returned unchanged.
        Returns ``""`` when no host is present (e.g. ``":5000"``).
        """
        addr = addr.strip()
        if addr.startswith("["):
            end = addr.find("]")
            return addr[1:end] if end != -1 else addr.strip("[]")
        # Only a single ':' with a numeric suffix is treated as host:port;
        # anything else is an unbracketed IPv6 literal or a plain host/FQDN.
        if addr.count(":") == 1:
            host, port = addr.rsplit(":", 1)
            if port.isdigit():
                return host
        return addr

    @staticmethod
    def _resolve_host(host: str) -> str:
        """Resolve ``host`` to an IP address string, falling back to ``host``.

        Resolution uses ``AF_UNSPEC`` so both A/AAAA records and v4/v6 literals
        are accepted. The first result is used, i.e. the OS address-selection
        policy decides between IPv4 and IPv6.
        """
        try:
            infos = socket.getaddrinfo(
                host,
                None,
                family=socket.AF_UNSPEC,
                type=socket.SOCK_STREAM,
            )
        except (OSError, UnicodeError):
            # socket.gaierror (an OSError) for unresolvable names; UnicodeError
            # when IDNA-encoding a malformed hostname (empty/over-long label).
            # Keep the literal/FQDN as-is in both cases.
            return host
        if not infos:
            return host
        return infos[0][4][0]

    @staticmethod
    def _get_bootstrap_info(engine: sgl.Engine) -> Tuple[str, int]:
        """Extract bootstrap host and port from SGLang engine.

        The host is taken from ``server_args.dist_init_addr`` when it contains
        a host part (resolved via DNS when possible). Otherwise it comes from
        ``get_local_ip_auto()``. IPv6 hosts are returned bracket-wrapped, so
        ``f"{host}:{port}"`` is always a valid address.

        Args:
            engine: The SGLang engine instance.

        Returns:
            Tuple of (bootstrap_host, bootstrap_port).
        """
        server_args = engine.tokenizer_manager.server_args
        bootstrap_port = server_args.disaggregation_bootstrap_port

        host_core = ""
        if server_args.dist_init_addr:
            host_core = BaseWorkerHandler._extract_host(server_args.dist_init_addr)

        if host_core:
            bootstrap_host = BaseWorkerHandler._resolve_host(host_core)
        else:
            # Unset, blank, or port-only dist_init_addr: use the local IP.
            bootstrap_host = get_local_ip_auto()

        # Wrap IPv6 literal with brackets so f"{host}:{port}" stays valid.
        if ":" in bootstrap_host and not bootstrap_host.startswith("["):
            bootstrap_host = f"[{bootstrap_host}]"

        return bootstrap_host, bootstrap_port

    def _propagate_trace_context_to_sglang(
        self, context: Context, bootstrap_room: int = 0
    ):
        """Propagate Dynamo's trace context to SGLang for distributed tracing.

        The payload follows the format SGLang expects, derived from the
        ``to_dict()`` method in
        https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/tracing/trace.py.
        It is a base64-encoded JSON object keyed by the bootstrap room.
        Propagation is best-effort: missing or non-hexadecimal IDs are skipped
        rather than failing the request.

        Args:
            context: Dynamo Context object containing trace information.
            bootstrap_room: Bootstrap room ID (0 for aggregated, actual room for disaggregated).
        """
        trace_id = context.trace_id
        span_id = context.span_id
        if not trace_id or not span_id:
            return

        try:
            prev_span = {
                "span_id": int(span_id, 16),
                "trace_id": int(trace_id, 16),
            }
        except (TypeError, ValueError):
            logging.debug(
                f"Skipping SGLang trace propagation; invalid trace_id={trace_id!r} or span_id={span_id!r}"
            )
            return

        # Build trace context for SGLang. The traceparent uses version "00"
        # and trace flags "01" (sampled).
        trace_context = {
            str(bootstrap_room): {
                "root_span": {"traceparent": f"00-{trace_id}-{span_id}-01"},
                "prev_span": prev_span,
            }
        }

        # Encode and propagate
        base64_context = base64.b64encode(
            json.dumps(trace_context, ensure_ascii=False).encode("utf-8")
        ).decode("utf-8")
        sglang_trace.trace_set_remote_propagate_context(base64_context)

    @staticmethod
    def _peek_request_id(request_id_future: asyncio.Future) -> Any:
        """Return the future's SGLang request ID if available, else ``"unknown"``.

        Never raises and never blocks. Pending, cancelled and failed futures
        all yield ``"unknown"``.
        """
        if (
            request_id_future.done()
            and not request_id_future.cancelled()
            and request_id_future.exception() is None
        ):
            return request_id_future.result()
        return "unknown"

    async def _handle_cancellation(
        self, request_id_future: asyncio.Future, context: Context
    ):
        """Background task to handle cancellation by monitoring context state.

        Waits for the SGLang request ID, then waits until the context is
        killed or stopped, and then aborts that request via
        ``engine.tokenizer_manager.abort_request``.

        Args:
            request_id_future: Future that will be set with the SGLang request ID
                              when the first response arrives.
            context: Context object for cancellation handling.
        """
        context_id = context.id()
        try:
            logging.debug(f"Cancellation monitor started for Context: {context_id}")

            # Always wait for the request ID to ensure we can abort the request
            sglang_request_id = await request_id_future
            logging.debug(
                f"Cancellation monitor received SGLang Request ID {sglang_request_id} for Context: {context_id}"
            )

            await context.async_killed_or_stopped()

            logging.info(
                f"Cancellation signal received for SGLang Request ID {sglang_request_id}, Context: {context_id}"
            )

            # Call abort_request on the tokenizer_manager through the engine
            tokenizer_manager = getattr(self.engine, "tokenizer_manager", None)
            if tokenizer_manager:
                logging.info(
                    f"Calling SGLang abort_request for Request ID {sglang_request_id}"
                )
                tokenizer_manager.abort_request(
                    rid=sglang_request_id, abort_all=False
                )
                logging.info(
                    f"Aborted SGLang Request ID {sglang_request_id} for Context: {context_id}"
                )
            else:
                logging.error(
                    f"SGLang tokenizer_manager not found for abort request: {context_id}"
                )
        except asyncio.CancelledError:
            # Task was cancelled, which is expected when generation completes
            request_id = self._peek_request_id(request_id_future)
            logging.debug(
                f"Cancellation monitor task cancelled for SGLang Request ID {request_id}, Context: {context_id}"
            )
            raise

    @asynccontextmanager
    async def _cancellation_monitor(
        self, request_id_future: asyncio.Future, context: Context
    ) -> AsyncGenerator[asyncio.Task, None]:
        """
        Context manager for monitoring request cancellation.
        Automatically creates a background task to monitor for cancellation and
        cleans it up when the context exits.

        If the monitor task already finished with an exception, that exception
        is retrieved and logged on exit rather than silently dropped.

        Args:
            request_id_future: Future that will be set with the SGLang request ID
                              when the first response arrives.
            context: Context object for cancellation handling

        Yields:
            asyncio.Task: The cancellation monitoring task being managed
        """
        logging.debug(f"Creating cancellation monitor task for Context: {context.id()}")

        # Start the cancellation monitoring task
        cancellation_task = asyncio.create_task(
            self._handle_cancellation(request_id_future, context)
        )

        try:
            yield cancellation_task
        finally:
            # Clean up the background cancellation task
            request_id = self._peek_request_id(request_id_future)

            if not cancellation_task.done():
                logging.debug(
                    f"Cancelling cancellation monitor task for SGLang Request ID {request_id}, Context: {context.id()}"
                )
                cancellation_task.cancel()
                try:
                    await cancellation_task
                except asyncio.CancelledError:
                    pass
            else:
                logging.debug(
                    f"Cancellation monitor task already completed for SGLang Request ID {request_id}, Context: {context.id()}"
                )
                # Retrieve the task outcome so a failure inside the monitor
                # (e.g. abort_request raising) is reported here instead of
                # being dropped as "Task exception was never retrieved".
                if not cancellation_task.cancelled():
                    monitor_error = cancellation_task.exception()
                    if monitor_error is not None:
                        logging.warning(
                            f"Cancellation monitor failed for SGLang Request ID {request_id}, Context: {context.id()}: {monitor_error!r}"
                        )