# worker_base.py
import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch

from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
                        update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

# Module-level logger. WorkerWrapperBase.execute_method uses it to report
# exceptions raised by methods invoked on a worker.
logger = init_logger(__name__)


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    @torch.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        Repeatedly calls ``execute_model`` without a request and returns as
        soon as it returns ``None``. You can stop the loop by executing a
        driver worker with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        while True:
            output = self.execute_model(execute_model_req=None)
            if output is None:
                return None

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Run model execution for ``execute_model_req``.

        Returning ``None`` terminates ``start_worker_execution_loop``, which
        calls this method with ``execute_model_req=None``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA adapter described by ``lora_request``.

        NOTE(review): the meaning of the bool result is defined by subclasses
        (presumably success) — confirm against implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter identified by ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter identified by ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return the set of LoRA adapter ids reported by this worker."""
        raise NotImplementedError


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.

    Every LoRA method raises ``ValueError`` naming the concrete worker type.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        # Raise (not return) the error: a returned exception object is truthy
        # and would be mistaken by callers for a successful pin.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")


@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Local inputs to each worker. May contain device-specific data. These
    fields should be broadcastable to other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """
        Pop fields from the given tensor_dict and populate a new instance of
        WorkerInput.

        ``virtual_engine`` is read but intentionally NOT popped. In
        ``LocalOrDistributedWorkerBase.execute_model`` the same dict is passed
        to the model runner afterwards. NOTE(review): presumably the runner also
        reads ``virtual_engine`` from it — confirm before switching to ``pop``.

        Raises:
            KeyError: if any expected key is missing from ``tensor_dict``.
        """
        return cls(
            num_seq_groups=tensor_dict.pop("num_seq_groups"),
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            virtual_engine=tensor_dict["virtual_engine"],
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Optional[Union[int, torch.Tensor]]]:
        """
        Extract broadcastable fields.

        The keys match those consumed by ``from_broadcasted_tensor_dict``.
        Values may be ``None`` for unset optional fields.
        """
        tensor_dict = {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
        }

        return tensor_dict


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.

        The driver worker builds the inputs from ``execute_model_req`` and, if
        ``do_metadata_broadcast`` is set, broadcasts them. Non-driver workers
        receive the inputs from that broadcast.

        Returns:
            ``None`` when there is no work: the driver got no request, or a
            non-driver received the empty stop broadcast.
            ``[]`` when the request contains zero sequence groups.
            ``[None]`` on pipeline ranks other than the last one.
            Otherwise the model runner's output, unchanged.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None

            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            model_input: ModelRunnerInputBase = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))
            num_steps = execute_model_req.num_steps

            if self.do_metadata_broadcast:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_data["num_steps"] = num_steps
                broadcast_tensor_dict(broadcast_data, src=0)
        else:
            assert self.do_metadata_broadcast
            broadcast_data = broadcast_tensor_dict(src=0)
            # An empty broadcast is the driver's stop signal (see above).
            if not broadcast_data:
                return None

            num_steps = broadcast_data.pop("num_steps")
            # WorkerInput pops its own keys first; the remaining dict is then
            # handed to the model runner.
            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict())

        # `kv_cache` is a property; read it once instead of twice.
        kv_caches = self.kv_cache
        output = self.model_runner.execute_model(
            model_input,
            kv_caches[worker_input.virtual_engine]
            if kv_caches is not None else None,
            intermediate_tensors,
            num_steps)

        if not get_pp_group().is_last_rank:
            # Hand this rank's output tensors to the next pipeline rank (which
            # receives them via recv_tensor_dict above) and return a
            # placeholder result.
            get_pp_group().send_tensor_dict(output.tensors)
            return [None]

        # The model runner's output is returned as-is (no additional wrapping).
        return output


class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapper, which only remembers the worker
    module and class name. Environment variables can then be set via
    `update_environment_variables`, and the real initialization happens in
    `init_worker`.
    """

    def __init__(self,
                 worker_module_name: str,
                 worker_class_name: str,
                 trust_remote_code: bool = False) -> None:
        """Record where the worker class lives; do not construct it yet.

        Args:
            worker_module_name: dotted module path imported by `init_worker`.
            worker_class_name: attribute of that module instantiated by
                `init_worker`.
            trust_remote_code: if True, call `init_cached_hf_modules()` now.
        """
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        # Set by `init_worker`; until then `execute_method` targets `self`.
        self.worker: Optional[WorkerBase] = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply ``envs`` to the process environment."""
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()

        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        """Call ``method`` by name on the worker (or on this wrapper if the
        worker has not been initialized yet) and return its result.

        Any exception is logged and then re-raised unchanged.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # If the driver worker also executes methods, exceptions in the
            # other workers may cause a deadlock in RPC frameworks like Ray.
            # See https://github.com/vllm-project/vllm/issues/3455
            # Log the error so the user can see and fix it.
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare `raise` re-raises the active exception with its original
            # traceback.
            raise