"examples/deployments/router_standalone/worker.py" did not exist on "13a99b7f76c4052c8164d0ca84da5477e5dbcda3"
init_diffusion.py 8.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import os
from typing import Awaitable, Callable

import sglang as sgl

from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.health_check import (
    ImageDiffusionHealthCheckPayload,
    SglangHealthCheckPayload,
    VideoGenerationHealthCheckPayload,
)
from dynamo.sglang.publisher import handle_non_leader_node, setup_sgl_metrics
from dynamo.sglang.register import (
    register_image_diffusion_model,
    register_model_with_readiness_gate,
    register_video_generation_model,
)
from dynamo.sglang.request_handlers import (
    DiffusionWorkerHandler,
    ImageDiffusionWorkerHandler,
    VideoGenerationWorkerHandler,
)


async def init_llm_diffusion(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    shutdown_endpoints: list,
    run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
) -> None:
    """Initialize diffusion language model worker component.

    Builds an SGLang engine from ``config.server_args``, resolves the
    ``namespace.component.endpoint`` endpoint and publishes it into
    ``shutdown_endpoints`` (the list is mutated in place so the caller sees it).
    Non-leader nodes (``node_rank >= 1``) hand off to
    ``handle_non_leader_node`` and return. The leader node serves the endpoint
    and registers the model concurrently.

    Args:
        runtime: Dynamo distributed runtime used to resolve the endpoint and
            to register engine routes on the handler.
        config: Parsed worker configuration (``server_args`` / ``dynamo_args``).
        shutdown_event: Event passed through to ``DiffusionWorkerHandler``.
        shutdown_endpoints: Caller-owned list whose contents are replaced with
            the generate endpoint.
        run_deferred_handlers: Optional coroutine factory awaited last during
            teardown on the leader path.

    Raises:
        Exception: Any error from serving the endpoint or registering the
            model is logged and re-raised.
    """
    server_args, dynamo_args = config.server_args, config.dynamo_args

    logging.info(
        f"Initializing diffusion worker with algorithm: {server_args.dllm_algorithm}"
    )
    if server_args.dllm_algorithm_config:
        logging.info(
            f"Using diffusion algorithm config: {server_args.dllm_algorithm_config}"
        )

    # Must be set before the engine is constructed so it takes effect for
    # the engine's startup on non-zero-rank nodes.
    if server_args.node_rank >= 1:
        os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"

    engine = sgl.Engine(server_args=server_args)

    generate_endpoint = runtime.endpoint(
        f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
    )

    # Replace contents in place so the caller's list reference is updated.
    shutdown_endpoints[:] = [generate_endpoint]

    publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
        engine, config, generate_endpoint
    )

    if server_args.node_rank >= 1:
        await handle_non_leader_node(engine, publisher, metrics_task)
        return

    # Passed to registration as its readiness gate; presumably set elsewhere
    # once the worker is ready to accept traffic (TODO confirm).
    ready_event = asyncio.Event()

    handler = DiffusionWorkerHandler(
        engine, config, publisher, generate_endpoint, shutdown_event
    )
    handler.register_engine_routes(runtime)

    health_check_payload = SglangHealthCheckPayload(
        engine, use_text_input=dynamo_args.use_sglang_tokenizer
    ).to_dict()

    logging.info(
        f"Registering diffusion model with endpoint types: {dynamo_args.endpoint_types}"
    )

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=metrics_labels,
                health_check_payload=health_check_payload,
            ),
            register_model_with_readiness_gate(
                engine,
                generate_endpoint,
                server_args,
                dynamo_args,
                output_type=parse_endpoint_types(dynamo_args.endpoint_types),
                readiness_gate=ready_event,
            ),
        )
    except Exception as e:
        logging.error(f"Failed to serve diffusion endpoints: {e}")
        raise
    finally:
        metrics_task.cancel()
        try:
            await metrics_task
        except asyncio.CancelledError:
            logging.info("Metrics task successfully cancelled")
        except Exception:
            # A task that already died with an error re-raises it here; don't
            # let that skip handler cleanup and the deferred handlers below.
            logging.exception("Metrics task raised during shutdown")
        try:
            handler.cleanup()
        finally:
            # Deferred handlers must run even if handler cleanup fails.
            if run_deferred_handlers is not None:
                logging.info("Running deferred handlers")
                await run_deferred_handlers()


async def init_image_diffusion(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_endpoints: list,
    run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
) -> None:
    """Initialize image diffusion worker component.

    Loads a ``DiffGenerator`` from ``server_args.model_path`` and sizes it from
    ``tp_size`` / ``dp_size`` (each defaulting to 1 when absent). It then
    resolves the ``namespace.component.endpoint`` endpoint and publishes it into
    ``shutdown_endpoints`` (mutated in place). Finally it serves the endpoint
    while registering the image model concurrently.

    Args:
        runtime: Dynamo distributed runtime used to resolve the endpoint.
        config: Parsed worker configuration (``server_args`` / ``dynamo_args``).
        shutdown_endpoints: Caller-owned list whose contents are replaced with
            the generate endpoint.
        run_deferred_handlers: Optional coroutine factory awaited last during
            teardown.

    Raises:
        ValueError: If no model path was provided.
        Exception: Any error from serving or registration is logged and
            re-raised.
    """
    server_args, dynamo_args = config.server_args, config.dynamo_args

    # Imported lazily so the multimodal stack is only loaded for this worker type.
    from sglang.multimodal_gen import DiffGenerator

    if not server_args.model_path:
        raise ValueError("--model is required for diffusion workers")

    tp_size = getattr(server_args, "tp_size", 1)
    dp_size = getattr(server_args, "dp_size", 1)
    num_gpus = tp_size * dp_size

    dist_timeout = getattr(server_args, "dist_timeout", None)

    generator = DiffGenerator.from_pretrained(
        model_path=server_args.model_path,
        num_gpus=num_gpus,
        tp_size=tp_size,
        dp_size=dp_size,
        dist_timeout=dist_timeout,
    )

    fs_url = dynamo_args.media_output_fs_url

    generate_endpoint = runtime.endpoint(
        f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
    )

    # Replace contents in place so the caller's list reference is updated.
    shutdown_endpoints[:] = [generate_endpoint]

    handler = ImageDiffusionWorkerHandler(
        generator,
        config,
        publisher=None,
        fs=get_fs(fs_url),
    )

    health_check_payload = ImageDiffusionHealthCheckPayload(
        model_path=server_args.model_path
    ).to_dict()

    ready_event = asyncio.Event()

    # The global --output-modalities default is ["text"], which is wrong for
    # image diffusion workers: it makes the Rust registration path look for
    # config.json (an LLM artefact). Only override when the user hasn't
    # explicitly chosen a non-default value.
    output_modalities = dynamo_args.output_modalities
    if output_modalities is None or output_modalities == ["text"]:
        output_modalities = ["image"]
        logging.info(
            "Overriding output_modalities to ['image'] for image diffusion worker"
        )

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[],
                health_check_payload=health_check_payload,
            ),
            register_image_diffusion_model(
                generator,
                generate_endpoint,
                server_args,
                output_modalities=output_modalities,
                readiness_gate=ready_event,
            ),
        )
    except Exception as e:
        logging.error(f"Failed to serve image diffusion endpoints: {e}")
        raise
    finally:
        try:
            handler.cleanup()
        finally:
            # Deferred handlers must run even if handler cleanup fails.
            if run_deferred_handlers is not None:
                logging.info("Running deferred handlers")
                await run_deferred_handlers()


async def init_video_diffusion(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_endpoints: list,
    run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
) -> None:
    """Initialize video generation worker component.

    Loads a ``DiffGenerator`` from ``server_args.model_path`` and sizes it from
    ``tp_size`` / ``dp_size`` (each defaulting to 1 when absent). It then
    resolves the ``namespace.component.endpoint`` endpoint and publishes it into
    ``shutdown_endpoints`` (mutated in place). Finally it serves the endpoint
    while registering the video model concurrently.

    Args:
        runtime: Dynamo distributed runtime used to resolve the endpoint.
        config: Parsed worker configuration (``server_args`` / ``dynamo_args``).
        shutdown_endpoints: Caller-owned list whose contents are replaced with
            the generate endpoint.
        run_deferred_handlers: Optional coroutine factory awaited last during
            teardown.

    Raises:
        ValueError: If no model path was provided.
        Exception: Any error from serving or registration is logged and
            re-raised.
    """
    server_args, dynamo_args = config.server_args, config.dynamo_args

    # Imported lazily so the multimodal stack is only loaded for this worker type.
    from sglang.multimodal_gen import DiffGenerator

    if not server_args.model_path:
        raise ValueError("--model is required for video generation workers")

    tp_size = getattr(server_args, "tp_size", 1)
    dp_size = getattr(server_args, "dp_size", 1)
    num_gpus = tp_size * dp_size

    dist_timeout = getattr(server_args, "dist_timeout", None)

    generator = DiffGenerator.from_pretrained(
        model_path=server_args.model_path,
        num_gpus=num_gpus,
        tp_size=tp_size,
        dp_size=dp_size,
        dist_timeout=dist_timeout,
    )

    fs_url = dynamo_args.media_output_fs_url

    generate_endpoint = runtime.endpoint(
        f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
    )

    # Replace contents in place so the caller's list reference is updated.
    shutdown_endpoints[:] = [generate_endpoint]

    handler = VideoGenerationWorkerHandler(
        generator,
        config,
        publisher=None,
        fs=get_fs(fs_url),
    )

    health_check_payload = VideoGenerationHealthCheckPayload(
        model_path=server_args.model_path
    ).to_dict()

    ready_event = asyncio.Event()

    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[],
                health_check_payload=health_check_payload,
            ),
            register_video_generation_model(
                generator,
                generate_endpoint,
                server_args,
                readiness_gate=ready_event,
            ),
        )
    except Exception as e:
        logging.error(f"Failed to serve video generation endpoints: {e}")
        raise
    finally:
        try:
            handler.cleanup()
        finally:
            # Deferred handlers must run even if handler cleanup fails.
            if run_deferred_handlers is not None:
                logging.info("Running deferred handlers")
                await run_deferred_handlers()