worker.py 8.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import os

import uvloop
from disagg_router import PyDisaggregatedRouter
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind

from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker


class RequestHandler:
    """Serve ``generate`` requests from a vLLM engine, optionally with remote prefill.

    Each request is prefilled either locally by ``engine_client`` or remotely.
    Remote prefill happens by enqueuing a ``RemotePrefillRequest`` on a NATS
    prefill queue whose stream is named after the model. It is used only when
    ``do_remote_prefill`` is set and the optional disaggregated router agrees.
    """

    def __init__(
        self,
        model_name: str,
        engine_client: EngineClient,
        prefill_client,
        do_remote_prefill: bool,
        disaggregated_router: PyDisaggregatedRouter = None,
    ):
        """Store the engine handle and the remote-prefill configuration.

        Args:
            model_name: Served model name; also used as the prefill queue stream name.
            engine_client: vLLM engine client used to generate responses.
            prefill_client: Client for the prefill component. It is stored but
                not read by this class; prefill requests currently go through
                the NATS prefill queue instead (see the TODO in
                ``get_remote_prefill_request_callback``).
            do_remote_prefill: Master switch for remote prefill.
            disaggregated_router: Optional router that decides per request
                whether to prefill remotely. When ``None``, every request is
                treated as remote-eligible.
        """
        self.model_name = model_name
        self.client = engine_client
        self.prefill_client = prefill_client
        # Not read by this class; kept because external code may inspect them.
        self.openai_serving_chat = None
        self.initialized = False
        # Remote prefill is still subject to the router's per-request decision.
        self.do_remote_prefill = do_remote_prefill
        self.disaggregated_router = disaggregated_router

        self._prefill_queue_nats_server = os.getenv(
            "NATS_SERVER", "nats://localhost:4222"
        )
        self._prefill_queue_stream_name = model_name
        vllm_logger.info(
            "Prefill queue: %s:%s",
            self._prefill_queue_nats_server,
            self._prefill_queue_stream_name,
        )

        vllm_logger.info("RequestHandler initialized")

    def get_remote_prefill_request_callback(self):
        """Return a coroutine function that enqueues a remote prefill request.

        The returned callback enters ``PrefillQueue.get_instance(...)`` as an
        async context manager on every call, then enqueues the request on this
        model's stream.
        """
        # TODO: integrate prefill_queue to dynamo endpoint
        # NOTE(review): whether get_instance reuses a NATS connection is not
        # visible here; if it does not, this connects once per request.
        async def callback(request: RemotePrefillRequest):
            async with PrefillQueue.get_instance(
                nats_server=self._prefill_queue_nats_server,
                stream_name=self._prefill_queue_stream_name,
            ) as prefill_queue:
                await prefill_queue.enqueue_prefill_request(request)

        return callback

    @dynamo_endpoint(vLLMGenerateRequest, MyRequestOutput)
    async def generate(self, request):
        """Stream engine outputs for *request* as JSON ``MyRequestOutput`` deltas.

        Decides between local and remote prefill, forces delta streaming on the
        request's sampling params, then yields one serialized output per engine
        step.
        """
        prompt_len = len(request.engine_prompt["prompt_token_ids"])

        # TODO: consider prefix hit when deciding prefill locally or remotely
        if self.disaggregated_router is not None:
            # Always consult the router (even when remote prefill is disabled)
            # to keep its call pattern unchanged.
            prefill_remote = self.disaggregated_router.prefill_remote(
                prompt_len, request.prefix_hit_rate
            )
        else:
            # always prefill remotely if no disaggregated router is provided
            prefill_remote = True

        if self.do_remote_prefill and prefill_remote:
            remote_prefill_params = RemotePrefillParams(
                is_remote_prefill=True,
                remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
            )
            vllm_logger.debug(
                "Prefilling remotely for request %s with length %s",
                request.request_id,
                prompt_len,
            )
        else:
            remote_prefill_params = None
            vllm_logger.debug(
                "Prefilling locally for request %s with length %s",
                request.request_id,
                prompt_len,
            )

        # rust HTTP requires Delta streaming
        request.sampling_params.output_kind = RequestOutputKind.DELTA

        async for response in self.client.generate(
            prompt=request.engine_prompt,
            sampling_params=request.sampling_params,
            request_id=request.request_id,
            remote_prefill_params=remote_prefill_params,
        ):
            yield MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
            ).model_dump_json()


@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """Build the vLLM engine and serve its ``generate`` endpoint until cancelled.

    Registers the ``generate`` endpoint of the ``vllm`` component in the
    ``dynamo-init`` namespace. Depending on ``engine_args``, it also:

    * creates a client for the ``prefill`` component (``remote_prefill``);
    * exports worker/KV env vars and serves a KV metrics endpoint (``router == "kv"``);
    * publishes the engine's NIXL metadata (``remote_prefill``);
    * attaches a disaggregated router (``conditional_disagg``).

    ``runtime`` is presumably injected by ``@dynamo_worker()``, since the
    entry point calls ``worker(engine_args)`` with only one argument.
    """
    namespace = "dynamo-init"
    component_name = "vllm"

    component = runtime.namespace(namespace).component(component_name)
    await component.create_service()

    endpoint = component.endpoint("generate")

    if engine_args.remote_prefill:
        prefill_client = (
            await runtime.namespace(namespace)
            .component("prefill")
            .endpoint("generate")
            .client()
        )
    else:
        prefill_client = None

    # Only created for the KV router; None signals "no metrics endpoint".
    metrics_publisher = None
    if engine_args.router == "kv":
        # TODO: do we need these env vars?
        worker_id = endpoint.lease_id()
        os.environ["VLLM_WORKER_ID"] = str(worker_id)
        vllm_logger.info("Generate endpoint ID: %s", worker_id)

        os.environ["VLLM_KV_NAMESPACE"] = namespace
        os.environ["VLLM_KV_COMPONENT"] = component_name

        metrics_publisher = KvMetricsPublisher()

    async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
        served_model_name = (
            engine_args.served_model_name
            if engine_args.served_model_name is not None
            else "vllm"
        )

        if metrics_publisher is not None:
            engine_client.set_metrics_publisher(metrics_publisher)

            # Initially send dummy metrics to kick start,
            # vLLM will not update stat until forward pass is triggered
            metrics_publisher.publish(
                0,  # request_active_slots
                1024,  # request_total_slots
                0,  # kv_active_blocks
                1024,  # kv_total_blocks
                0,  # num_requests_waiting
                0.0,  # gpu_cache_usage_perc
                0.0,  # gpu_prefix_cache_hit_rate
            )

        if engine_args.remote_prefill:
            # Publish this engine's NIXL metadata so prefill workers can find it.
            metadata = engine_client.nixl_metadata
            metadata_store = NixlMetadataStore(namespace, runtime)
            await metadata_store.put(metadata.engine_id, metadata)

        if engine_args.conditional_disagg:
            disaggregated_router = PyDisaggregatedRouter(
                runtime,
                served_model_name,
                max_local_prefill_length=engine_args.max_local_prefill_length,
            )
        else:
            disaggregated_router = None

        handler = RequestHandler(
            model_name=served_model_name,
            engine_client=engine_client,
            prefill_client=prefill_client,
            do_remote_prefill=engine_args.remote_prefill,
            disaggregated_router=disaggregated_router,
        )
        endpoints = [endpoint.serve_endpoint(handler.generate)]
        if metrics_publisher is not None:
            endpoints.append(metrics_publisher.create_endpoint(component))
        # Serve generate (and metrics, if enabled) concurrently.
        await asyncio.gather(*endpoints)


def enforce_remote_prefill_constraints(engine_args) -> None:
    """Override engine settings that remote prefill does not support yet.

    Mutates *engine_args* in place and prints a notice for each setting it
    changes:

    * ``enable_chunked_prefill`` -> ``False``
    * ``preemption_mode`` -> ``"swap"``
    * ``pipeline_parallel_size`` -> ``1``
    """
    # Anything other than an explicit False is overridden. Presumably None
    # lets vLLM pick a default that could enable chunked prefill.
    if engine_args.enable_chunked_prefill is not False:
        print("Chunked prefill is not supported yet, setting to False")
        engine_args.enable_chunked_prefill = False

    if engine_args.preemption_mode != "swap":
        print(
            f"Preemption mode {engine_args.preemption_mode!r} is not supported "
            "yet, setting to 'swap'"
        )
        engine_args.preemption_mode = "swap"

    if engine_args.pipeline_parallel_size != 1:
        print(
            f"Pipeline parallel size {engine_args.pipeline_parallel_size} is "
            "not supported yet, setting to 1"
        )
        engine_args.pipeline_parallel_size = 1


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()

    if engine_args.remote_prefill:
        enforce_remote_prefill_constraints(engine_args)

    asyncio.run(worker(engine_args))