"vllm/model_executor/models/mistral.py" did not exist on "88c0268a18f1c724d59a624364635d5c7ac39408"
worker_base.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar
6
7
8
9

import torch
import torch.nn as nn

10
from vllm.config import VllmConfig, set_current_vllm_config
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
from vllm.multimodal import MULTIMODAL_REGISTRY
14
from vllm.tracing import instrument
15
from vllm.utils.import_utils import resolve_obj_by_qualname
16
from vllm.utils.system_utils import update_environment_variables
17
from vllm.v1.kv_cache_interface import KVCacheSpec
18
19

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
    from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
    # The real classes are only imported for static type checking. In this
    # module these names appear only in function annotations, which are
    # evaluated when the functions are defined, so at runtime they are bound
    # to a harmless placeholder instead.
    SchedulerOutput = object
    GrammarOutput = object
    AsyncModelRunnerOutput = object
    ModelRunnerOutput = object

logger = init_logger(__name__)

# Return type of the callable passed to `WorkerBase.apply_model`.
_R = TypeVar("_R")


class WorkerBase:
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.

    Most methods raise ``NotImplementedError`` and must be overridden by a
    hardware-specific subclass. ``check_health``, ``reset_mm_cache``,
    ``apply_model``, ``get_model_inspection``, ``vocab_size`` and ``shutdown``
    have default implementations.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        """
        Initialize common worker components.

        Args:
            vllm_config: Complete vLLM configuration
            local_rank: Local device index
            rank: Global rank in distributed setup
            distributed_init_method: Distributed initialization method
            is_driver_worker: Whether this worker handles driver
                responsibilities
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config

        # Imported here rather than at module level (presumably to defer
        # platform detection or avoid an import cycle -- TODO confirm).
        from vllm.platforms import current_platform

        self.current_platform = current_platform

        # NOTE: this writes into the ParallelConfig owned by `vllm_config`, so
        # the rank is visible to anything else holding the same config object.
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        # Device and model state; None until a subclass assigns them.
        self.device: torch.device | None = None
        self.model_runner: nn.Module | None = None

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Get specifications for KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> float:
        """Prepare model for execution through compilation/warmup.

        Returns:
            The accumulated compilation time in seconds.
        """
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check (override for device-specific checks).

        The default implementation does nothing.
        """
        return

    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError

    def reset_mm_cache(self) -> None:
        """Reset the model runner's multimodal cache, if it exposes one.

        Calls ``self.model_runner.reset_mm_cache()`` when that attribute exists
        and is callable; otherwise (including when ``model_runner`` is None)
        this is a no-op.
        """
        reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
        if callable(reset_fn):
            reset_fn()

    def get_model(self) -> nn.Module:
        """Return the model held by this worker. Must be overridden."""
        raise NotImplementedError

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        """Apply a function on the model inside this worker.

        Args:
            fn: Callable invoked with the result of ``get_model()``.

        Returns:
            Whatever ``fn`` returns.
        """
        return fn(self.get_model())

    def get_model_inspection(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        from vllm.model_inspection import format_model_inspection

        return format_model_inspection(self.get_model())

    def load_model(self) -> None:
        """Load model onto target device."""
        raise NotImplementedError

    def execute_model(
        self, scheduler_output: SchedulerOutput
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """If this method returns None, sample_tokens should be called immediately after
        to obtain the ModelRunnerOutput.

        Note that this design may be changed in future if/when structured outputs
        parallelism is re-architected.
        """
        raise NotImplementedError

    def sample_tokens(
        self, grammar_output: GrammarOutput
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
        """Should be called immediately after execute_model iff it returned None."""
        raise NotImplementedError

    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA adapter described by ``lora_request``. Must be overridden.

        The meaning of the returned bool is defined by the subclass
        (presumably whether the adapter was added -- confirm per implementation).
        """
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id``. Must be overridden.

        The meaning of the returned bool is defined by the subclass.
        """
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id``. Must be overridden.

        The meaning of the returned bool is defined by the subclass.
        """
        raise NotImplementedError

    def list_loras(self) -> set[int]:
        """Return the ids of the LoRA adapters known to this worker.

        Must be overridden.
        """
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Get vocabulary size from model configuration."""
        return self.model_config.get_vocab_size()

    def shutdown(self) -> None:
        """Clean up resources held by the worker. The default does nothing."""
        return


class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which only remembers its ranks; the
    worker class itself is resolved from `parallel_config.worker_cls` later.
    Then we call `update_environment_variables`, and the real initialization
    happens in `init_worker`.

    Attribute lookups that miss on the wrapper are forwarded to the wrapped
    worker via `__getattr__`.
    """

    def __init__(
        self,
        rpc_rank: int = 0,
        global_rank: int | None = None,
    ) -> None:
        """
        Initialize the worker wrapper with the given rpc_rank and global_rank.

        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.

        Args:
            rpc_rank: Index of this wrapper's entry in the per-worker lists
                passed to `update_environment_variables` and `init_worker`.
            global_rank: Index of this wrapper's entry in the list passed to
                `initialize_from_config`. Defaults to `rpc_rank`.
        """
        self.rpc_rank = rpc_rank
        self.global_rank = self.rpc_rank if global_rank is None else global_rank

        # Initialized after init_worker is called. These are annotations only:
        # the attributes do not exist on the instance until init_worker runs.
        self.worker: WorkerBase
        self.vllm_config: VllmConfig

    def shutdown(self) -> None:
        """Shut down the wrapped worker, if one has been created."""
        # `worker` is only assigned by init_worker. Use a default so that
        # shutting down a wrapper that was never initialized is a no-op.
        worker = getattr(self, "worker", None)
        if worker is not None:
            worker.shutdown()

    def update_environment_variables(
        self,
        envs_list: list[dict[str, str]],
    ) -> None:
        """Apply this worker's entry (index `rpc_rank`) of `envs_list`."""
        envs = envs_list[self.rpc_rank]
        update_environment_variables(envs)

    @instrument(span_name="Worker init")
    def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        Args:
            all_kwargs: One kwargs dict per worker; entry `rpc_rank` is used.
                It must contain `vllm_config`. An optional
                `shared_worker_lock` entry is consumed here and is not
                forwarded to the worker constructor. The caller's dict is
                left unmodified.

        Raises:
            ValueError: if `parallel_config.worker_cls` is not a qualified-name
                string, or if `shared_worker_lock` is missing while
                mm_processor_cache_type is "shm".
        """
        # Shallow copy so that popping `shared_worker_lock` below does not
        # mutate the caller's dict.
        kwargs = dict(all_kwargs[self.rpc_rank])

        vllm_config: VllmConfig | None = kwargs.get("vllm_config")
        assert vllm_config is not None, (
            "vllm_config is required to initialize the worker"
        )
        self.vllm_config = vllm_config

        vllm_config.enable_trace_function_call_for_thread()

        from vllm.plugins import load_general_plugins

        load_general_plugins()

        parallel_config = vllm_config.parallel_config
        if isinstance(parallel_config.worker_cls, str):
            worker_class: type[WorkerBase] = resolve_obj_by_qualname(
                parallel_config.worker_cls
            )
        else:
            raise ValueError(
                "passing worker_cls is no longer supported. "
                "Please keep the class in a separate module "
                "and pass the qualified name of the class as a string."
            )

        if parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                parallel_config.worker_extension_cls
            )
            extended_calls = []
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}."
                    )
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # Dynamically inherit the worker extension class. NOTE: this
                # rewrites `__bases__` on the resolved class object itself.
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls,
                )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls,
                    worker_class,
                    extended_calls,
                )

        shared_worker_lock = kwargs.pop("shared_worker_lock", None)
        if shared_worker_lock is None:
            msg = (
                "Missing `shared_worker_lock` argument from executor. "
                "This argument is needed for mm_processor_cache_type='shm'."
            )

            mm_config = vllm_config.model_config.multimodal_config
            if mm_config and mm_config.mm_processor_cache_type == "shm":
                raise ValueError(msg)
            else:
                logger.warning_once(msg)

            self.mm_receiver_cache = None
        else:
            self.mm_receiver_cache = (
                MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
                    vllm_config,
                    shared_worker_lock,
                )
            )

        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)

    def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
        """Pass entry `global_rank` of `kv_cache_configs` to the worker.

        Note: this list is indexed by `global_rank`, not `rpc_rank`.
        """
        kv_cache_config = kv_cache_configs[self.global_rank]
        assert self.vllm_config is not None
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(kv_cache_config)  # type: ignore

    def init_device(self):
        """Call the worker's `init_device` with the vLLM config set as current."""
        assert self.vllm_config is not None
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def __getattr__(self, attr: str):
        """Forward attribute lookups that miss on the wrapper to the worker."""
        # `__getattr__` runs only after normal lookup fails. Fetch `worker`
        # with `object.__getattribute__` so that, before init_worker assigns
        # it, we raise AttributeError instead of re-entering `__getattr__`
        # for "worker" forever (RecursionError). This also keeps `hasattr`,
        # `getattr(..., default)` and copy/pickle probing working.
        try:
            worker = object.__getattribute__(self, "worker")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r} "
                "(the worker has not been initialized; call init_worker first)"
            ) from None
        return getattr(worker, attr)

    def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
        """Pass new requests' `mm_features` through the receiver cache.

        Each entry of `scheduler_output.scheduled_new_reqs` has its
        `mm_features` replaced by `mm_cache.get_and_update_features(...)`.
        No-op when no receiver cache was configured.
        """
        mm_cache = self.mm_receiver_cache
        if mm_cache is None:
            return

        for req_data in scheduler_output.scheduled_new_reqs:
            req_data.mm_features = mm_cache.get_and_update_features(
                req_data.mm_features
            )

    def execute_model(
        self, scheduler_output: SchedulerOutput
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """Apply the multimodal receiver cache, then delegate to the worker."""
        self._apply_mm_cache(scheduler_output)

        return self.worker.execute_model(scheduler_output)

    def reset_mm_cache(self) -> None:
        """Clear the receiver-side multimodal cache (if any) and the worker's."""
        mm_receiver_cache = self.mm_receiver_cache
        if mm_receiver_cache is not None:
            mm_receiver_cache.clear_cache()

        self.worker.reset_mm_cache()