# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar

import torch
import torch.nn as nn

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
    from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
    # Method annotations in this module are evaluated when the methods are
    # defined, so these names must exist at runtime. Bind them to `object`
    # instead of importing the real classes outside of type checking
    # (presumably to avoid import cycles / import cost -- confirm).
    SchedulerOutput = object
    GrammarOutput = object
    AsyncModelRunnerOutput = object
    ModelRunnerOutput = object

logger = init_logger(__name__)

# Generic return type of the callable passed to `WorkerBase.apply_model`.
_R = TypeVar("_R")


class WorkerBase:
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.

    Methods that raise ``NotImplementedError`` here must be provided by a
    concrete subclass. ``check_health``, ``shutdown``, ``reset_mm_cache``,
    ``apply_model`` and ``vocab_size`` have working default implementations.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """
        Initialize common worker components.

        Args:
            vllm_config: Complete vLLM configuration
            local_rank: Local device index
            rank: Global rank in distributed setup
            distributed_init_method: Distributed initialization method
            is_driver_worker: Whether this worker handles driver
                responsibilities
        """
        # Cache the individual sub-configs for convenient access.
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config

        # Imported lazily inside __init__ rather than at module level
        # (presumably to avoid import-time cost or cycles -- confirm).
        from vllm.platforms import current_platform

        self.current_platform = current_platform

        # NOTE: this writes the rank into vllm_config.parallel_config in place.
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        # Device and model state; populated by subclasses.
        self.device: torch.device | None = None
        self.model_runner: nn.Module | None = None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Get specifications for KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Prepare model for execution through compilation/warmup."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return

    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError

    def reset_mm_cache(self) -> None:
        """Reset the model runner's multimodal cache, if it has one.

        Calls ``self.model_runner.reset_mm_cache()`` when that attribute exists
        and is callable. Otherwise, including when ``model_runner`` is still
        ``None``, this is a no-op.
        """
        reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
        if callable(reset_fn):
            reset_fn()

    def get_model(self) -> nn.Module:
        """Return the model held by this worker (subclasses must implement)."""
        raise NotImplementedError

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        """Apply a function on the model inside this worker."""
        return fn(self.get_model())

    def load_model(self) -> None:
        """Load model onto target device."""
        raise NotImplementedError

    def execute_model(
        self, scheduler_output: SchedulerOutput
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """If this method returns None, sample_tokens should be called immediately after
        to obtain the ModelRunnerOutput.

        Note that this design may be changed in the future if/when structured
        outputs parallelism is re-architected.
        """
        raise NotImplementedError

    def sample_tokens(
        self, grammar_output: GrammarOutput
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Should be called immediately after execute_model iff it returned None."""
        raise NotImplementedError

    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter described by ``lora_request``.

        Subclasses must implement. The returned bool presumably reports
        success -- confirm against concrete implementations.
        """
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id`` (subclasses implement).

        The returned bool presumably reports success -- confirm.
        """
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id`` (subclasses implement).

        The returned bool presumably reports success -- confirm.
        """
        raise NotImplementedError

    def list_loras(self) -> set[int]:
        """Return a set of LoRA adapter ids (subclasses implement).

        Presumably the adapters currently loaded in this worker -- confirm.
        """
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()

    def shutdown(self) -> None:
        """Clean up resources held by the worker."""
        return


class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the wrapper, which only records its rank information.
    Then we call `update_environment_variables`, and the real initialization
    happens in `init_worker`, which resolves the worker class from
    `vllm_config.parallel_config.worker_cls` and constructs it.

    Attribute lookups that fail on the wrapper itself are forwarded to the
    wrapped worker via `__getattr__`.
    """

    def __init__(
        self,
        rpc_rank: int = 0,
        global_rank: int | None = None,
    ) -> None:
        """
        Initialize the worker wrapper with the given rpc_rank and global_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.

        Args:
            rpc_rank: Index of this worker within its executor. Used to pick
                this worker's entry from the per-worker lists passed to
                `update_environment_variables` and `init_worker`.
            global_rank: Index used to pick this worker's entry in
                `initialize_from_config`. Defaults to `rpc_rank`.
        """
        self.rpc_rank = rpc_rank
        self.global_rank = self.rpc_rank if global_rank is None else global_rank

        # Initialized after init_worker is called. These are declarations only:
        # until then the attributes are unset, and looking them up falls
        # through to __getattr__, which raises AttributeError.
        self.worker: WorkerBase
        self.vllm_config: VllmConfig

    def shutdown(self) -> None:
        """Shut down the wrapped worker, if one has been created.

        Safe to call before `init_worker`. In that case there is nothing to
        shut down and this is a no-op.
        """
        # `worker` is unset until init_worker runs; use a default instead of
        # letting the lookup fail.
        worker = getattr(self, "worker", None)
        if worker is not None:
            worker.shutdown()

    def adjust_rank(self, rank_mapping: dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(
        self,
        envs_list: list[dict[str, str]],
    ) -> None:
        """Apply this worker's entry of `envs_list` to the process environment.

        Args:
            envs_list: One dict of environment variables per worker, indexed
                by `rpc_rank`.
        """
        envs = envs_list[self.rpc_rank]
        key = "CUDA_VISIBLE_DEVICES"
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        # This call resolves to the module-level helper, not this method.
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        Args:
            all_kwargs: One kwargs dict per worker, indexed by `rpc_rank`.
                The selected entry must contain `vllm_config`. It may also
                contain `shared_worker_lock`, which is consumed here and not
                forwarded to the worker constructor.

        Raises:
            ValueError: If `parallel_config.worker_cls` is not a qualified-name
                string, or if `shared_worker_lock` is missing while
                `mm_processor_cache_type == "shm"`.
        """
        # Shallow copy: `shared_worker_lock` is popped below, and the caller's
        # dict should not be mutated (e.g. if init_worker is retried).
        kwargs = dict(all_kwargs[self.rpc_rank])

        vllm_config: VllmConfig | None = kwargs.get("vllm_config")
        assert vllm_config is not None, (
            "vllm_config is required to initialize the worker"
        )
        self.vllm_config = vllm_config

        vllm_config.enable_trace_function_call_for_thread()

        from vllm.plugins import load_general_plugins

        load_general_plugins()

        parallel_config = vllm_config.parallel_config
        if isinstance(parallel_config.worker_cls, str):
            worker_class: type[WorkerBase] = resolve_obj_by_qualname(
                parallel_config.worker_cls
            )
        else:
            raise ValueError(
                "passing worker_cls is no longer supported. "
                "Please keep the class in a separate module "
                "and pass the qualified name of the class as a string."
            )

        if parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                parallel_config.worker_extension_cls
            )
            extended_calls = []
            # Skip if the extension has already been injected as a direct base.
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}."
                    )
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # dynamically inherit the worker extension class
                # (this mutates `worker_class` itself, not a subclass)
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls,
                )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls,
                    worker_class,
                    extended_calls,
                )

        shared_worker_lock = kwargs.pop("shared_worker_lock", None)
        if shared_worker_lock is None:
            msg = (
                "Missing `shared_worker_lock` argument from executor. "
                "This argument is needed for mm_processor_cache_type='shm'."
            )

            mm_config = vllm_config.model_config.multimodal_config
            if mm_config and mm_config.mm_processor_cache_type == "shm":
                raise ValueError(msg)
            else:
                logger.warning_once(msg)

            self.mm_receiver_cache = None
        else:
            self.mm_receiver_cache = worker_receiver_cache_from_config(
                vllm_config,
                MULTIMODAL_REGISTRY,
                shared_worker_lock,
            )

        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)

    def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
        """Forward this worker's KV-cache config (indexed by `global_rank`)
        to the wrapped worker, with the vLLM config set as current."""
        kv_cache_config = kv_cache_configs[self.global_rank]
        assert self.vllm_config is not None
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(kv_cache_config)  # type: ignore

    def init_device(self):
        """Call the wrapped worker's `init_device` with the vLLM config set
        as current."""
        assert self.vllm_config is not None
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: str | bytes, *args, **kwargs):
        """Run `method` (resolved by `run_method`) with the given arguments.

        Any exception is logged and then re-raised unchanged.
        """
        try:
            # method resolution order:
            # if a method is defined in this class, it will be called directly.
            # otherwise, since we define `__getattr__` and redirect attribute
            # query to `self.worker`, the method will be called on the worker.
            return run_method(self, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (
                f"Error executing method {method!r}. "
                "This might cause deadlock in distributed execution."
            )
            logger.exception(msg)
            raise

    def __getattr__(self, attr: str):
        """Forward attribute lookups that fail on the wrapper to the worker.

        Raises:
            AttributeError: If the worker has not been created yet (before
                `init_worker`), or if the worker lacks `attr`.
        """
        # __getattr__ only runs after normal lookup fails. Read `worker`
        # straight from the instance dict. If it is not set yet, `self.worker`
        # would re-enter __getattr__ and recurse until RecursionError. That
        # would also break copy/pickle probing of e.g. `__setstate__`.
        try:
            worker = self.__dict__["worker"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r} "
                "(the worker has not been initialized; call init_worker first)"
            ) from None
        return getattr(worker, attr)

    def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
        """Replace each new request's `mm_features` with the result of the
        multimodal receiver cache's `get_and_update_features`.

        This is a no-op when no receiver cache was created.
        """
        mm_cache = self.mm_receiver_cache
        if mm_cache is None:
            return

        for req_data in scheduler_output.scheduled_new_reqs:
            req_data.mm_features = mm_cache.get_and_update_features(
                req_data.mm_features
            )

    def execute_model(
        self, scheduler_output: SchedulerOutput
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Resolve multimodal features through the receiver cache, then
        delegate to the wrapped worker's `execute_model`."""
        self._apply_mm_cache(scheduler_output)

        return self.worker.execute_model(scheduler_output)

    def reset_mm_cache(self) -> None:
        """Clear the wrapper's multimodal receiver cache (if any), then reset
        the wrapped worker's multimodal cache."""
        mm_receiver_cache = self.mm_receiver_cache
        if mm_receiver_cache is not None:
            mm_receiver_cache.clear_cache()

        self.worker.reset_mm_cache()