worker_base.py 12.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
zhuwenwen's avatar
zhuwenwen committed
5
import numa
6
import time
7
from abc import abstractmethod
8

9
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar,
10
                    Union, Type)
11

12
import cloudpickle
13
import torch.nn as nn
14

15
from vllm.config import VllmConfig, set_current_vllm_config
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
from vllm.sequence import ExecuteModelRequest
19
from vllm.utils import (enable_trace_function_call_for_thread,
20
                        resolve_obj_by_qualname, run_method,
21
22
                        update_environment_variables,
                        warn_for_unimplemented_methods)
23

24
from vllm.v1.outputs import SamplerOutput
25
26
27


# Module-level logger for this file, created through vLLM's `init_logger`
# and named after the current module.
logger = init_logger(__name__)
28

29
30
# Generic return type; used by `WorkerBase.apply_model` so the result type of
# the applied callable is carried through to the caller.
_R = TypeVar("_R")

31

zhuwenwen's avatar
zhuwenwen committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def bind_to_numa(local_rank):
    """Bind the current process to the NUMA node configured for a rank.

    The target node is read from the environment variable
    ``VLLM_RANK{local_rank}_NUMA``. When that variable is unset, negative,
    or not a valid integer, a warning is logged and no binding is done.

    Args:
        local_rank: Rank used to build the environment variable name.

    Raises:
        ValueError: If the configured node id is larger than the highest
            node id reported by ``numa.get_max_node()``.
    """
    env_str = f"VLLM_RANK{local_rank}_NUMA"
    try:
        numa_node = int(os.getenv(env_str, -1))
    except ValueError:
        # A non-integer value counts as "set incorrectly": fall through to
        # the warning below instead of crashing the worker.
        numa_node = -1

    # No binding when the variable is unset or misconfigured.
    # TODO: bind automatically based on the machine topology.
    if numa_node < 0:
        logger.warning("%s is unset or set incorrectly, vllm will not bind to numa! %s = %d", env_str, env_str, numa_node)
        return

    if numa_node > numa.get_max_node():
        raise ValueError(f"NUMA node {numa_node} is not available.")

    numa.bind([numa_node])
    
    
49
50
@warn_for_unimplemented_methods
class WorkerBase:
    """Hardware-agnostic worker interface.

    Concrete workers subclass this so that vLLM can keep per-hardware
    implementations apart. The interface also abstracts control plane
    communication, e.g., communicating request metadata to other workers.
    """
    # TODO (carried over from the original source, which gives no detail).
    # Evaluated once, when the class body runs: True only if the
    # VLLM_TREE_DECODING environment variable is exactly '1'.
    tree_decoding = os.getenv("VLLM_TREE_DECODING") == "1"

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Keep the full config and expose each sub-config as an attribute."""
        cfg = vllm_config
        self.vllm_config = cfg

        # Individual sub-configs, copied by reference for convenient access.
        self.model_config = cfg.model_config
        self.cache_config = cfg.cache_config
        self.lora_config = cfg.lora_config
        self.load_config = cfg.load_config
        self.parallel_config = cfg.parallel_config
        self.scheduler_config = cfg.scheduler_config
        self.device_config = cfg.device_config
        self.speculative_config = cfg.speculative_config
        self.observability_config = cfg.observability_config
        self.kv_transfer_config = cfg.kv_transfer_config
        self.compilation_config = cfg.compilation_config

        # Imported here rather than at module level, as in the original.
        from vllm.platforms import current_platform
        self.current_platform = current_platform

    def init_device(self) -> None:
        """Set up device state (for example loading the model or making
        other on-device memory allocations). Subclasses must override.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Create the KV cache with the given sizes, expressed in blocks.
        Subclasses must override.
        """
        raise NotImplementedError

    def get_model(self) -> nn.Module:
        """Return the model held by this worker. Subclasses must override."""
        raise NotImplementedError

    def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R:
        """Call `fn` with this worker's model and return its result."""
        model = self.get_model()
        return fn(model)

    def load_model(self) -> None:
        """Load the model onto the target device. Subclasses must override."""
        raise NotImplementedError

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Run the model for one request. Subclasses must override."""
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Run the model execution loop in a parallel worker.

        The loop ends as soon as `execute_model` returns None, i.e. when a
        driver worker is executed with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        with self.current_platform.inference_mode():
            # Keep stepping until a step reports None.
            while self.execute_model(execute_model_req=None) is not None:
                pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many blocks are available for the GPU KV cache and
        for the swappable CPU KV cache.

        Implementations may profile or use other heuristics to size the
        caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks]. num_gpu_blocks are
        blocks that are "active" on the device and can be appended to;
        num_cpu_blocks are "swapped" blocks in CPU memory that cannot be
        appended to. Subclasses must override.
        """
        raise NotImplementedError

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of one cache block. Used in speculative
        decoding. Subclasses must override.
        """
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """LoRA hook taking a `LoRARequest`. Subclasses must override."""
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        """LoRA hook taking a LoRA id. Subclasses must override."""
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        """LoRA hook taking a LoRA id. Subclasses must override."""
        raise NotImplementedError

    def list_loras(self) -> Set[int]:
        """LoRA hook returning a set of ints. Subclasses must override."""
        raise NotImplementedError

    # @property
    # @abstractmethod
    # def cache_engines(self) -> Optional[List[CacheEngine]]:
    #     raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model configuration."""
        return self.model_config.get_vocab_size()

    def shutdown(self) -> None:
        """Release resources held by the worker (nothing to do by default)."""
        return None

164

165
166
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.

    The wrapper is instantiated first, without a worker. After that
    `update_environment_variables` can be called, and the real worker
    object is only created in `init_worker`. Attributes that the wrapper
    does not define itself are forwarded to the wrapped worker.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        self.vllm_config: Optional[VllmConfig] = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        model_config = vllm_config.model_config
        # model_config can be None in tests
        if model_config is not None and model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    def shutdown(self) -> None:
        """Shut down the wrapped worker, if one has been created."""
        if self.worker is not None:
            self.worker.shutdown()

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        Ranks that are not keys of `rank_mapping` are left unchanged.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                str]]) -> None:
        """Apply the environment variables meant for this worker.

        `envs_list` is indexed by rpc_rank; only the entry for this
        wrapper's rank is applied.
        """
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        `all_kwargs` is indexed by rpc_rank. The selected entry must contain
        a non-None "vllm_config" and, when NUMA binding is enabled
        (VLLM_NUMA_BIND > 0, the default), a "local_rank".
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config")
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        parallel_config = self.vllm_config.parallel_config
        if isinstance(parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(parallel_config.worker_cls)
        else:
            logger.warning(
                "passing worker_cls as a class object is strongly deprecated,"
                " as the serialization of class objects can be tricky and"
                " error-prone. To be safe, please keep the class in a separate"
                " module and pass the qualified name of the class as a string."
            )
            assert isinstance(parallel_config.worker_cls, bytes)
            # NOTE(security): unpickling runs arbitrary code, so these bytes
            # must only ever come from a trusted configuration.
            worker_class = cloudpickle.loads(parallel_config.worker_cls)

        if parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                parallel_config.worker_extension_cls)
            extended_calls = []
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}.")
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # dynamically inherit the worker extension class; this
                # mutates `worker_class` itself, not just this instance
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls, )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls, worker_class, extended_calls)

        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
            assert self.worker is not None

        # NUMA binding is on unless VLLM_NUMA_BIND is set to a value <= 0.
        numa_bind = int(os.getenv("VLLM_NUMA_BIND", 1))
        if numa_bind > 0:
            # Bind the current process to the NUMA node configured for it.
            local_rank = kwargs['local_rank']
            bind_to_numa(local_rank)

            pid = os.getpid()
            logger.info("########## %d process(rank%s) is running on CPU(s): %s", pid, str(local_rank), str(os.sched_getaffinity(pid)))
            logger.info("########## %d process(rank%s) is running on memnode(s): %s", pid, str(local_rank), str(numa.get_membind()))

    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
        """Forward this rank's KV cache config to the wrapped worker.

        `kv_cache_configs` is indexed by rpc_rank.
        """
        kv_cache_config = kv_cache_configs[self.rpc_rank]
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(kv_cache_config)  # type: ignore

    def init_device(self):
        """Run the wrapped worker's `init_device` under the current config."""
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Run `method` on this wrapper (or, via `__getattr__`, the worker).

        Any exception is logged with a hint about possible deadlock and then
        re-raised unchanged.
        """
        try:
            # method resolution order:
            # if a method is defined in this class, it will be called directly.
            # otherwise, since we define `__getattr__` and redirect attribute
            # query to `self.worker`, the method will be called on the worker.
            return run_method(self, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise

    def __getattr__(self, attr: str):
        """Forward attributes not found on the wrapper to the worker.

        Python calls this only after normal lookup has failed. `worker` is
        read from the instance `__dict__` instead of via `self.worker`: if
        `__init__` has not run (for example while an instance is being
        unpickled or copied), `self.worker` would re-enter this method and
        recurse without bound.

        Raises:
            AttributeError: If no worker has been created yet, or the worker
                does not have the attribute.
        """
        worker = self.__dict__.get("worker")
        if worker is None:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
                " (the worker has not been initialized)")
        return getattr(worker, attr)