"examples/llm/benchmarks/disagg.py" did not exist on "c9130f8f8ce264379131e9ee2973534fe4cbf713"
utils.py 10 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
from contextlib import asynccontextmanager
from typing import ClassVar, Optional

import msgspec
from nats.aio.client import Client as NATS
from nats.errors import Error as NatsError
from nats.js.client import JetStreamContext
from nats.js.errors import NotFoundError
from pydantic import BaseModel
from vllm.distributed.device_communicators.nixl import NixlMetadata

from dynamo._core import Client
from dynamo.runtime import DistributedRuntime

logger = logging.getLogger(__name__)


# Pydantic model carrying a prompt string together with a request id string.
# NOTE(review): the defaults look like placeholder example values rather than
# meaningful fallbacks — confirm with the callers.
class GeneralRequest(BaseModel):
    prompt: str = "user input"
    request_id: str = "id_string"


# Pydantic model carrying a worker's output string together with a request id
# string.
# NOTE(review): the defaults look like placeholder example values rather than
# meaningful fallbacks — confirm with the callers.
class GeneralResponse(BaseModel):
    worker_output: str = "generated output"
    request_id: str = "id_string"


# msgspec Struct describing a remote prefill request: the id of an engine and
# the id of the request.
# - omit_defaults=True: per the msgspec docs, fields still equal to their
#   default are left out of the encoded message.
# - dict=True: per the msgspec docs, instances get a __dict__, so extra
#   attributes can be attached at runtime.
# NOTE(review): both defaults look like placeholders; with omit_defaults=True
# a field left at its default would not be serialized at all — confirm that
# callers always set both fields.
class RemotePrefillRequest(msgspec.Struct, omit_defaults=True, dict=True):
    engine_id: str = "Engine ID"
    request_id: str = "id_string"


class NATSQueue:
    """Task queue built on a NATS JetStream stream.

    Tasks are published to the ``<stream_name>.queue`` subject and read back
    through a durable pull subscription named ``worker-group``. A shared
    instance is managed through :meth:`get_instance` and :meth:`shutdown`.
    """

    # Shared instance and the lock guarding its creation and teardown.
    _instance: ClassVar[Optional["NATSQueue"]] = None
    _lock: ClassVar[asyncio.Lock] = asyncio.Lock()

    def __init__(
        self,
        stream_name: str = "default",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        """Record the connection settings; no network I/O happens here.

        Args:
            stream_name: JetStream stream name. ``/`` and ``\\`` are replaced
                with ``_`` before use.
            nats_server: URL handed to the NATS client on connect.
            dequeue_timeout: Timeout passed to each fetch in
                :meth:`dequeue_task`.
        """
        self.nats_url = nats_server
        self._nc: Optional[NATS] = None
        self._js: Optional[JetStreamContext] = None
        # TODO: check if this is needed
        # Sanitize stream_name to remove path separators
        self._stream_name = stream_name.replace("/", "_").replace("\\", "_")
        # The stream captures every subject below "<stream_name>."
        self._subject = f"{self._stream_name}.*"
        self.dequeue_timeout = dequeue_timeout
        self._subscriber: Optional[JetStreamContext.PullSubscription] = None

    @classmethod
    @asynccontextmanager
    async def get_instance(
        cls,
        *,
        stream_name: str = "default",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        """Yield the shared NATSQueue instance, creating it on first use.

        The keyword arguments are only consulted when no instance exists
        yet. The class lock is held for as long as the caller stays inside
        the ``async with`` body, because the ``yield`` sits inside the locked
        region; :meth:`shutdown` acquires the same lock.

        If the body raises an ``Exception``, the instance is closed and
        discarded before the exception is re-raised.
        """
        # TODO: check if this _lock is needed with GIL
        async with cls._lock:
            if cls._instance is None:
                cls._instance = cls(
                    stream_name=stream_name,
                    nats_server=nats_server,
                    dequeue_timeout=dequeue_timeout,
                )
                await cls._instance.connect()
            try:
                yield cls._instance
            except Exception:
                if cls._instance:
                    await cls._instance.close()
                cls._instance = None
                raise

    # TODO: check to see if this can be replaced by something like get_instance().close()
    @classmethod
    async def shutdown(cls):
        """Explicitly close the shared instance if it exists."""
        async with cls._lock:
            if cls._instance:
                await cls._instance.close()
                cls._instance = None

    async def connect(self):
        """Establish the connection and create the stream if needed.

        Does nothing when a client object is already held.

        Raises:
            ConnectionError: If any NATS error occurs; the partially opened
                connection is closed first.
        """
        try:
            if self._nc is None:
                self._nc = NATS()
                await self._nc.connect(self.nats_url)
                self._js = self._nc.jetstream()
                # Check if stream exists, if not create it
                try:
                    await self._js.stream_info(self._stream_name)
                except NotFoundError:
                    await self._js.add_stream(
                        name=self._stream_name, subjects=[self._subject]
                    )
                # Create persistent subscriber
                self._subscriber = await self._js.pull_subscribe(
                    f"{self._stream_name}.queue", durable="worker-group"
                )
        except NatsError as e:
            await self.close()
            raise ConnectionError(f"Failed to connect to NATS: {e}") from e

    async def ensure_connection(self):
        """Ensure we have an active connection, reconnecting if necessary."""
        if self._nc is not None and self._nc.is_closed:
            # connect() only acts when no client is held, so a closed client
            # must be dropped first; otherwise the dead client, JetStream
            # context and subscription would be reused forever.
            self._nc = None
            self._js = None
            self._subscriber = None
        if self._nc is None:
            await self.connect()

    async def close(self):
        """Close the connection when done and forget all NATS handles."""
        if self._nc:
            await self._nc.close()
            self._nc = None
            self._js = None
            self._subscriber = None

    # TODO: is enqueue/dequeue_object a better name for a general queue?
    async def enqueue_task(self, task_data: bytes) -> None:
        """Publish ``task_data`` to the ``<stream_name>.queue`` subject.

        Raises:
            RuntimeError: If publishing fails with a NATS error.
        """
        await self.ensure_connection()
        try:
            await self._js.publish(f"{self._stream_name}.queue", task_data)  # type: ignore
        except NatsError as e:
            raise RuntimeError(f"Failed to enqueue task: {e}") from e

    async def dequeue_task(self) -> Optional[bytes]:
        """Fetch one task and return its raw bytes, or None if none arrived.

        The message is acknowledged before its payload is returned. A fetch
        that times out after ``dequeue_timeout`` yields ``None``.

        Raises:
            RuntimeError: If fetching fails with any other NATS error.
        """
        await self.ensure_connection()
        try:
            msgs = await self._subscriber.fetch(1, timeout=self.dequeue_timeout)  # type: ignore
            if msgs:
                msg = msgs[0]
                await msg.ack()
                return msg.data
            return None
        except asyncio.TimeoutError:
            # Keep this handler ahead of the NatsError one so that a timeout
            # is reported as "nothing to do" rather than as a failure.
            return None
        except NatsError as e:
            raise RuntimeError(f"Failed to dequeue task: {e}") from e

    async def get_queue_size(self) -> int:
        """Return the ``num_pending`` count of the ``worker-group`` consumer.

        Raises:
            RuntimeError: If the consumer info lookup fails with a NATS error.
        """
        await self.ensure_connection()
        try:
            # Get consumer info to get pending messages count
            consumer_info = await self._js.consumer_info(  # type: ignore
                self._stream_name, "worker-group"
            )
            # Return number of pending messages (real-time queue size)
            return consumer_info.num_pending
        except NatsError as e:
            raise RuntimeError(f"Failed to get queue size: {e}") from e


class PrefillQueue(NATSQueue):
    """
    NATSQueue specialised for RemotePrefillRequest payloads.

    Requests are JSON-encoded with msgspec when enqueued and decoded back
    into RemotePrefillRequest when dequeued.
    The stream name defaults to "prefill_queue".
    """

    def __init__(
        self,
        stream_name="prefill_queue",
        nats_server: str = "nats://localhost:4222",
        dequeue_timeout: float = 1,
    ):
        # Same settings as the base queue; only the default stream differs.
        super().__init__(stream_name, nats_server, dequeue_timeout)

    async def enqueue_prefill_request(
        self, prefill_request: RemotePrefillRequest
    ) -> None:
        """JSON-encode ``prefill_request`` and push it onto the queue."""
        await self.enqueue_task(msgspec.json.encode(prefill_request))

    async def dequeue_prefill_request(self) -> Optional[RemotePrefillRequest]:
        """Pop one request and decode it; return None when nothing was fetched."""
        payload = await self.dequeue_task()
        if payload is None:
            return None
        return msgspec.json.decode(payload, type=RemotePrefillRequest)


class NixlMetadataStore:
    """Publish and look up per-engine NixlMetadata through the runtime's etcd client.

    Values are msgpack-encoded with msgspec and stored under
    ``<namespace>/nixl_metadata/<engine_id>``. Successful lookups are kept in
    an in-process cache.
    """

    NIXL_METADATA_KEY = "nixl_metadata"

    def __init__(self, namespace: str, runtime: DistributedRuntime) -> None:
        """Bind the store to ``namespace`` using ``runtime.etcd_client()``.

        Raises:
            Exception: If the runtime provides no etcd client.
        """
        self._namespace = namespace

        # TODO Remove metadata from etcd on delete
        # Engine ids this store instance has written.
        self._stored: set[str] = set()

        # Engine id -> metadata already fetched and decoded by get().
        self._cached: dict[str, NixlMetadata] = {}
        self._client = runtime.etcd_client()
        if self._client is None:
            raise Exception("Cannot be used with static workers")
        self._key_prefix = f"{self._namespace}/{NixlMetadataStore.NIXL_METADATA_KEY}"

    async def put(self, engine_id, metadata: NixlMetadata):
        """Encode ``metadata`` and write it under the key for ``engine_id``."""
        serialized_metadata = msgspec.msgpack.encode(metadata)
        key = "/".join([self._key_prefix, engine_id])
        await self._client.kv_put(key, serialized_metadata, None)
        self._stored.add(engine_id)
        # Evict any previously cached value so that a later get() re-reads
        # what was just written instead of returning the stale copy.
        self._cached.pop(engine_id, None)

    async def get(self, engine_id) -> NixlMetadata:
        """Return the metadata for ``engine_id``, from the cache when possible.

        On a cache miss the first entry returned by the prefix lookup is
        decoded, cached and returned.

        Raises:
            Exception: If no entry is found or the lookup/decoding fails;
                the underlying error is chained as the cause.
        """
        try:
            if engine_id in self._cached:
                return self._cached[engine_id]

            key = "/".join([self._key_prefix, engine_id])
            # NOTE(review): this is a prefix query, so an engine id that is a
            # prefix of another id could presumably match that other entry —
            # confirm against kv_get_prefix's result ordering/semantics.
            key_values = await self._client.kv_get_prefix(key)
            deserialized_metadata = None

            # Only the first returned entry is used.
            for item in key_values:
                deserialized_metadata = msgspec.msgpack.decode(
                    item["value"], type=NixlMetadata
                )
                break

            if deserialized_metadata is None:
                raise Exception("metadata not found in etcd")

            self._cached[engine_id] = deserialized_metadata

            # TODO watch for changes and update cache

            # self._client.add_watch_callback(
            #     key,
            #     self._watch_callback,
            # )

        except Exception as e:
            raise Exception(f"Error retrieving metadata for engine {engine_id}") from e

        return deserialized_metadata


async def check_required_workers(
    workers_client: "Client",
    required_workers: int,
    on_change=True,
    poll_interval=5,
    tag="",
):
    """Wait until the minimum number of workers are ready.

    Polls ``workers_client.endpoint_ids()`` every ``poll_interval`` seconds
    until it reports at least ``required_workers`` entries.

    Args:
        workers_client: Object exposing ``endpoint_ids()``.
        required_workers: Minimum number of endpoint ids to wait for.
        on_change: When true, the "waiting" message is printed only when the
            worker count differs from the one last printed (and once at the
            start); when false, it is printed on every poll.
        poll_interval: Seconds to sleep between polls.
        tag: Prefix included in the "waiting" message.

    Returns:
        The endpoint ids from the poll that satisfied the requirement.
    """
    worker_ids = workers_client.endpoint_ids()
    num_workers = len(worker_ids)
    last_reported = None  # No count printed yet, so the first message always shows
    # The loop condition always tests the freshest poll, so we return as soon
    # as enough workers are seen and never hand back a later, smaller list.
    while num_workers < required_workers:
        if (not on_change) or num_workers != last_reported:
            print(
                f" {tag} Waiting for more workers to be ready.\n"
                f" Current: {num_workers},"
                f" Required: {required_workers}"
            )
            last_reported = num_workers
        await asyncio.sleep(poll_interval)
        worker_ids = workers_client.endpoint_ids()
        num_workers = len(worker_ids)

    print(f"Workers ready: {worker_ids}")
    return worker_ids