worker.py 8.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import os

import uvloop
from disagg_router import PyDisaggregatedRouter
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind

from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker


class RequestHandler:
    """Handles generate requests against a local vLLM engine, optionally
    delegating the prefill phase to a remote worker via a NATS-backed queue.

    Whether a given request is prefilled remotely is decided per-request by
    the (optional) disaggregated router; without one, every request is
    prefilled remotely whenever remote prefill is enabled at all.
    """

    def __init__(
        self,
        model_name: str,
        engine_client: EngineClient,
        prefill_client,
        do_remote_prefill: bool,
        disaggregated_router: PyDisaggregatedRouter = None,
    ):
        self.model_name = model_name
        self.client = engine_client
        self.prefill_client = prefill_client
        self.openai_serving_chat = None
        self.initialized = False
        # Remote prefill is still controlled by the router.
        self.do_remote_prefill = do_remote_prefill
        self.disaggregated_router = disaggregated_router

        # Connection details for the shared prefill queue; the stream is
        # keyed by model name so multiple models can share one NATS server.
        self._prefill_queue_nats_server = os.getenv(
            "NATS_SERVER", "nats://localhost:4222"
        )
        self._prefill_queue_stream_name = model_name
        vllm_logger.info(
            f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
        )

        print("RequestHandler initialized")

    def get_remote_prefill_request_callback(self):
        """Return an async callback that enqueues a prefill request onto
        the NATS prefill queue for a remote prefill worker to pick up."""
        # TODO: integrate prefill_queue to dynamo endpoint

        async def callback(request: RemotePrefillRequest):
            async with PrefillQueue.get_instance(
                nats_server=self._prefill_queue_nats_server,
                stream_name=self._prefill_queue_stream_name,
            ) as prefill_queue:
                await prefill_queue.enqueue_prefill_request(request)

        return callback

    @dynamo_endpoint(vLLMGenerateRequest, MyRequestOutput)
    async def generate(self, request):
        """Stream generation results for one request, choosing local vs.
        remote prefill before handing the request to the engine."""
        # TODO: consider prefix hit when deciding prefill locally or remotely
        if self.disaggregated_router is None:
            # Always prefill remotely if no disaggregated router is provided.
            disagg_router_decision = True
        else:
            disagg_router_decision = self.disaggregated_router.prefill_remote(
                len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
            )

        remote_prefill_params = None
        if self.do_remote_prefill and disagg_router_decision:
            remote_prefill_params = RemotePrefillParams(
                is_remote_prefill=True,
                remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
            )
            vllm_logger.info(
                f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
            )
        else:
            vllm_logger.info(
                f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
            )

        # rust HTTP requires Delta streaming
        request.sampling_params.output_kind = RequestOutputKind.DELTA

        response_stream = self.client.generate(
            prompt=request.engine_prompt,
            sampling_params=request.sampling_params,
            request_id=request.request_id,
            remote_prefill_params=remote_prefill_params,
        )
        async for response in response_stream:
            yield MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
            ).model_dump_json()


@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """Bring up the vLLM generate service on the distributed runtime.

    Creates the "vllm" component and its "generate" endpoint, optionally
    wires up a remote-prefill client, KV metrics publishing, NIXL metadata
    sharing, and a disaggregated router, then serves until shutdown.
    """
    component = runtime.namespace("dynamo-init").component("vllm")
    await component.create_service()

    endpoint = component.endpoint("generate")

    # Client used to route prefill work to the dedicated prefill component.
    prefill_client = None
    if engine_args.remote_prefill:
        prefill_client = (
            await runtime.namespace("dynamo-init")
            .component("prefill")
            .endpoint("generate")
            .client()
        )

    if engine_args.router == "kv":
        # TODO: do we need these env vars?
        VLLM_WORKER_ID = endpoint.lease_id()
        os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
        vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")

        VLLM_KV_NAMESPACE = "dynamo-init"
        os.environ["VLLM_KV_NAMESPACE"] = str(VLLM_KV_NAMESPACE)

        VLLM_KV_COMPONENT = "vllm"
        os.environ["VLLM_KV_COMPONENT"] = str(VLLM_KV_COMPONENT)

        metrics_publisher = KvMetricsPublisher()

    async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
        served_model_name = (
            "vllm"
            if engine_args.served_model_name is None
            else engine_args.served_model_name
        )

        if engine_args.router == "kv":
            engine_client.set_metrics_publisher(metrics_publisher)
            # Initially send dummy metrics to kick start,
            # vLLM will not update stat until forward pass is triggered
            metrics_publisher.publish(
                0,  # request_active_slots
                1024,  # request_total_slots
                0,  # kv_active_blocks
                1024,  # kv_total_blocks
                0,  # num_requests_waiting
                0.0,  # gpu_cache_usage_perc
                0.0,  # gpu_prefix_cache_hit_rate
            )

        if engine_args.remote_prefill:
            # Publish this engine's NIXL metadata so prefill workers can
            # establish transfers to it.
            metadata = engine_client.nixl_metadata
            metadata_store = NixlMetadataStore("dynamo-init", runtime)
            await metadata_store.put(metadata.engine_id, metadata)

        disaggregated_router = None
        if engine_args.conditional_disagg:
            disaggregated_router = PyDisaggregatedRouter(
                runtime,
                served_model_name,
                max_local_prefill_length=engine_args.max_local_prefill_length,
            )

        handler = RequestHandler(
            model_name=served_model_name,
            engine_client=engine_client,
            prefill_client=prefill_client,
            do_remote_prefill=engine_args.remote_prefill,
            disaggregated_router=disaggregated_router,
        )
        serving_tasks = [endpoint.serve_endpoint(handler.generate)]
        if engine_args.router == "kv":
            serving_tasks.append(metrics_publisher.create_endpoint(component))
        await asyncio.gather(*serving_tasks)


if __name__ == "__main__":
    uvloop.install()
    engine_args = parse_vllm_args()

    if engine_args.remote_prefill:
        # Remote prefill currently constrains several engine options; rather
        # than failing, warn and force-correct each unsupported setting.
        # Each entry: (attribute, forced value, "needs fixing" predicate, message).
        overrides = [
            (
                "enable_chunked_prefill",
                False,
                lambda v: v is not False,
                "Chunked prefill is not supported yet, setting to False",
            ),
            (
                "preemption_mode",
                "swap",
                lambda v: v != "swap",
                "Preemption mode is not supported yet, setting to swap",
            ),
            (
                "pipeline_parallel_size",
                1,
                lambda v: v != 1,
                "Pipeline parallel size is not supported yet, setting to 1",
            ),
        ]
        for attr, forced_value, needs_fix, message in overrides:
            if needs_fix(getattr(engine_args, attr)):
                print(message)
                setattr(engine_args, attr, forced_value)

    asyncio.run(worker(engine_args))