worker.py 6.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import logging
import os
import signal
import socket
from typing import Optional

from utils.args import parse_vllm_args
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)

# Additional vLLM imports for DP worker
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_tcp_uri
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor

37
from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
38
39
40
41
42
43
44
45
46

logger = logging.getLogger(__name__)


class VllmBaseWorker:
    """Common base class for the vLLM-backed dynamo workers in this module.

    Constructing an instance parses the engine arguments, installs
    SIGTERM/SIGINT handlers and exports the NIXL side-channel host and port as
    environment variables. The engine client itself is only created later, by
    ``async_init``.
    """

    def __init__(self) -> None:
        """Parse engine arguments, install signal handlers, export side-channel env."""
        # The concrete subclass name is handed to the argument parser.
        # NOTE(review): presumably it selects the per-service configuration and
        # the empty second argument is a prefix -- confirm in utils.args.
        class_name = self.__class__.__name__
        self.engine_args = parse_vllm_args(class_name, "")

        signal.signal(signal.SIGTERM, self.graceful_shutdown)
        signal.signal(signal.SIGINT, self.graceful_shutdown)

        # Exported here, i.e. before async_init() builds the engine client.
        self.set_side_channel_host_and_port()

    async def async_init(self) -> None:
        """Build the vLLM engine client and store it on ``self.engine_client``.

        The object returned by ``build_async_engine_client_from_engine_args``
        is entered manually via ``__aenter__`` and kept on
        ``self._engine_context``; nothing in this class calls ``__aexit__``.

        Raises:
            RuntimeError: If no engine context was returned.
        """
        self._engine_context = build_async_engine_client_from_engine_args(
            self.engine_args
        )
        if self._engine_context is None:
            raise RuntimeError("Failed to initialize engine client")

        self.engine_client = await self._engine_context.__aenter__()
        logger.info("VllmWorker has been initialized")

    def graceful_shutdown(self, signum, frame) -> None:
        """
        Gracefully shutdown the worker by shutting down the dynamo runtime.
        This will
            1. disable the generate endpoint so no new requests are accepted.
            2. wait until all in-flight requests are completed.
            3. finish the awaiting for the endpoint service.
            4. rely on python's garbage collection to clean up the GPU.

        Args:
            signum: Signal number delivered to the handler (unused).
            frame: Interrupted stack frame (unused).
        """
        logger.info("Shutting down dynamo runtime...")
        dynamo_context["runtime"].shutdown()
        logger.info("Dynamo runtime shutdown complete.")

    def shutdown_vllm_worker(self, signum, frame) -> None:
        """Shutdown the worker immediately by killing the background loop.

        Closes the engine client and then stops the event loop. A failure while
        closing is logged (with traceback) and does not prevent the loop from
        being stopped.

        Args:
            signum: Signal number delivered to the handler (unused).
            frame: Interrupted stack frame (unused).
        """
        loop = asyncio.get_event_loop()
        try:
            self.engine_client.close()
            logger.info("VllmWorker shutdown complete")
        except Exception as e:
            # logger.exception keeps the traceback, which a plain error() with
            # only the message would drop.
            logger.exception("Error during shutdown: %s", e)
        finally:
            loop.stop()

    @endpoint()
    async def generate(self, request: vLLMGenerateRequest):
        """Stream generation outputs for ``request`` as JSON strings.

        Args:
            request: Carries the ``prompt``, ``sampling_params`` and
                ``request_id`` forwarded to the engine client.

        Yields:
            One ``MyRequestOutput`` per engine response, serialized with
            ``model_dump_json()``.
        """
        gen = self.engine_client.generate(
            prompt=request.prompt,
            sampling_params=request.sampling_params,
            request_id=request.request_id,
        )

        async for response in gen:
            # Lazy %-formatting: the message is only rendered when DEBUG is
            # enabled, instead of on every streamed response.
            logger.debug(
                "Response kv_transfer_params: %s", response.kv_transfer_params
            )
            yield MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
                metrics=response.metrics,
                kv_transfer_params=response.kv_transfer_params,
            ).model_dump_json()

    def set_side_channel_host_and_port(
        self, hostname: Optional[str] = None, port: Optional[int] = None
    ) -> None:
        """Export the NIXL side-channel host and port via environment variables.

        vLLM V1 NixlConnector creates a side channel to exchange metadata with
        other NIXL connectors. This sets ``VLLM_NIXL_SIDE_CHANNEL_HOST`` and
        ``VLLM_NIXL_SIDE_CHANNEL_PORT`` for it.

        Args:
            hostname: Host to export; defaults to ``socket.gethostname()``.
            port: Port to export; when omitted, a port number assigned by the
                OS to a temporary socket is used.
        """
        if hostname is None:
            hostname = socket.gethostname()
        if port is None:
            # The probe socket is closed when the ``with`` block exits, so the
            # port is only discovered here, not reserved.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", 0))  # Bind to a free port provided by the host.
                port = s.getsockname()[1]  # Get the port number assigned.
        logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_HOST to %s", hostname)
        os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
        logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port)
        os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)


@service(
    dynamo={"enabled": True, "namespace": "dynamo"},
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class VllmPrefillWorker(VllmBaseWorker):
    """Prefill worker service registered in the ``dynamo`` namespace."""

    @async_on_start
    async def async_init(self):
        """Run the base-class engine initialization, then log readiness."""
        await super().async_init()
        logger.info("VllmPrefillWorker has been initialized")


@service(
    dynamo={"enabled": True, "namespace": "dynamo"},
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class VllmDecodeWorker(VllmBaseWorker):
    """Decode worker service registered in the ``dynamo`` namespace."""

    @async_on_start
    async def async_init(self):
        """Run the base-class engine initialization, then log readiness."""
        await super().async_init()
        logger.info("VllmDecodeWorker has been initialized")
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195


@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class VllmDpWorker(VllmBaseWorker):
    """Data-parallel worker that starts local vLLM engine-core processes.

    Unlike the other workers in this module, ``async_init`` here does not call
    the base implementation, so ``self.engine_client`` is never created on
    this class.
    NOTE(review): the inherited ``generate`` endpoint reads
    ``self.engine_client`` -- confirm it is not expected to be served here.
    """

    @async_on_start
    async def async_init(self):
        """Create the engine config and launch the local engine-core processes.

        The handshake address is built from the parallel config's
        ``data_parallel_master_ip`` and the engine args'
        ``data_parallel_rpc_port``. The resulting manager is stored on
        ``self.engine_manager``.
        """
        vllm_config = self.engine_args.create_engine_config(
            usage_context=UsageContext.OPENAI_API_SERVER
        )

        parallel_config = vllm_config.parallel_config
        local_engine_count = parallel_config.data_parallel_size_local
        host = parallel_config.data_parallel_master_ip
        port = self.engine_args.data_parallel_rpc_port  # add to config too
        handshake_address = get_tcp_uri(host, port)

        self.engine_manager = CoreEngineProcManager(
            target_fn=EngineCoreProc.run_engine_core,
            local_engine_count=local_engine_count,
            start_index=self.engine_args.data_parallel_start_rank,
            local_start_index=0,
            vllm_config=vllm_config,
            on_head_node=False,
            handshake_address=handshake_address,
            executor_class=Executor.get_class(vllm_config),
            log_stats=not self.engine_args.disable_log_stats,
        )

    def shutdown_vllm_engine(self, signum, frame):
        """Shutdown the engine manager.

        Calls ``join_first()`` on the manager and always closes it afterwards.
        NOTE(review): ``join_first()`` presumably blocks until the first
        engine process exits -- confirm against vLLM.

        Args:
            signum: Signal number delivered to the handler (unused).
            frame: Interrupted stack frame (unused).
        """
        # A signal can arrive before async_init() has created the manager;
        # without this guard both the try body and the finally clause would
        # raise AttributeError.
        manager = getattr(self, "engine_manager", None)
        if manager is None:
            logger.info("Engine manager was never started; nothing to shut down.")
            return

        try:
            manager.join_first()
        finally:
            logger.info("Shutting down.")
            manager.close()