worker_base.py 18.3 KB
Newer Older
1
import dataclasses
2
3
import importlib
import os
4
import time
5
from abc import ABC, abstractmethod
6
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
7

8
9
import torch

10
from vllm.config import ObservabilityConfig
11
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
from vllm.platforms import current_platform
15
16
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SamplerOutput)
17
from vllm.utils import (enable_trace_function_call_for_thread,
18
                        update_environment_variables)
19
20
21
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)
22
23

# Module-level logger; WorkerWrapperBase.execute_method uses it to report
# exceptions raised by worker methods before re-raising them.
logger = init_logger(__name__)
24
25
26
27


class WorkerBase(ABC):
    """Abstract worker interface.

    Lets vLLM keep hardware-specific worker implementations cleanly separated,
    and abstracts control-plane communication as well (for example, sending
    request metadata to the other workers).
    """

    @abstractmethod
    def init_device(self) -> None:
        """Set up device state, e.g. load the model or perform other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many KV-cache blocks are available on the device (GPU)
        and in swappable CPU memory.

        Implementations may profile or apply heuristics to size the caches.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. ``num_gpu_blocks`` counts the
            "active" blocks on the device that can be appended to;
            ``num_cpu_blocks`` counts the "swapped" blocks held in CPU memory,
            which cannot be appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate the KV cache using the given number of blocks."""
        raise NotImplementedError

    @current_platform.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Run the model-execution loop on a parallel (non-driver) worker.

        Each iteration calls ``execute_model`` without a request; the loop
        finishes as soon as that call returns ``None``, i.e. when the driver
        worker executes with an empty output. See
        `stop_remote_worker_execution_loop` for details.
        """
        while self.execute_model(execute_model_req=None) is not None:
            pass

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Run the model for ``execute_model_req``.

        ``start_worker_execution_loop`` calls this with ``None`` repeatedly and
        stops once it returns ``None``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Size in bytes of one cache block; used by speculative decoding."""
        raise NotImplementedError

    # LoRA adapter management. Workers that do not support LoRA can derive
    # from LoraNotSupportedWorkerBase for stub implementations.
    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase whose LoRA methods all raise
    ``ValueError``, for workers that do not support LoRA.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        # Raise, like the sibling stubs. Returning the exception object would
        # hand callers a truthy value and make the pin look successful.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
121
122


123
124
125
126
127
128
129
130
131
132
@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Inputs local to one worker for an execution step.

    Values may live on the worker's device; every field has to be
    broadcastable to the other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """Build a WorkerInput from a received broadcast dict.

        The entries this class owns are popped out of ``tensor_dict`` (so the
        caller can pass the remainder on), except ``virtual_engine``, which is
        only read. A missing key raises ``KeyError``.
        """
        kwargs: Dict[str, Any] = {}
        for key in ("num_seq_groups", "blocks_to_swap_in",
                    "blocks_to_swap_out", "blocks_to_copy"):
            kwargs[key] = tensor_dict.pop(key)
        # Read but left in place. NOTE(review): presumably the model runner's
        # input is rebuilt from the same dict and also needs this key —
        # confirm before switching to pop().
        kwargs["virtual_engine"] = tensor_dict["virtual_engine"]
        kwargs["num_steps"] = tensor_dict.pop("num_steps")
        return cls(**kwargs)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the broadcastable fields as a new ``{name: value}`` dict."""
        field_names = ("num_seq_groups", "blocks_to_swap_in",
                       "blocks_to_swap_out", "blocks_to_copy",
                       "virtual_engine", "num_steps")
        return {name: getattr(self, name) for name in field_names}


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """Receive the driver's broadcast and rebuild this worker's inputs.

        Only valid on non-driver workers when metadata broadcast is enabled.
        Returns None when the driver broadcast an empty dict, which signals
        that there is nothing more to execute.
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        # WorkerInput pops its own keys from broadcast_data before the model
        # runner reads the same dict.
        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        model_input = (
            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                broadcast_data))

        kwargs = extract_previous_hidden_states(broadcast_data)

        return model_input, worker_input, kwargs

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """Prepare the driver's inputs and, if enabled, broadcast them to the
        other workers.
        """
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        # Attached only after the broadcast, so the callback stays local to the
        # driver (presumably because a callable cannot be broadcast).
        if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.

        Returns None when there is nothing to execute: the driver was given no
        request, or a non-driver worker received the driver's empty broadcast.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.

        Returns:
            None if `prepare_input` found nothing to execute; ``[]`` if the
            step has zero sequence groups; ``[None]`` on pipeline-parallel
            ranks other than the last (their intermediate tensors are sent
            over the PP group instead); otherwise the model runner's output.
        """
        # NOTE: the timer starts before prepare_input, so the reported
        # model_execute_time also covers input preparation and broadcast.
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        # Execution time accumulated by earlier PP ranks, when collected.
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        output = self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            **kwargs,
        )

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model. Nothing is broadcast between workers here.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        # Pass the same request fields as the driver path
        # (_get_driver_input_and_broadcast); previously virtual_engine and
        # finished_requests_ids were dropped here.
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )
390

391

392
393
394
395
396
397
class WorkerWrapperBase:
    """
    Lazily initializes the real worker.

    The WorkerWrapper is instantiated first and only remembers how to obtain
    the worker class. Environment variables can then be set through
    `update_environment_variables`, and the real initialization happens in
    `init_worker`.

    If worker_class_fn is specified, it is called to get the worker class.
    Otherwise, the worker class is obtained by dynamically importing it using
    worker_module_name and worker_class_name.
    """

    def __init__(
        self,
        worker_module_name: str,
        worker_class_name: str,
        trust_remote_code: bool = False,
        worker_class_fn: Optional[Callable[[],
                                           Type[WorkerBase]]] = None) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker_class_fn = worker_class_fn
        # Populated by init_worker.
        self.worker: Optional[WorkerBase] = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply ``envs`` to the process environment.

        If ``envs`` sets CUDA_VISIBLE_DEVICES, any existing value is removed
        first so it is replaced without a warning.
        """
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        # Resolves to the module-level helper imported from vllm.utils, not to
        # this staticmethod: class attributes are not in scope inside methods.
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()

        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if self.worker_class_fn:
            worker_class = self.worker_class_fn()
        else:
            mod = importlib.import_module(self.worker_module_name)
            worker_class = getattr(mod, self.worker_class_name)

        self.worker = worker_class(*args, **kwargs)
        assert self.worker is not None

    def execute_method(self, method, *args, **kwargs):
        """Call ``method`` (by name) on the worker, or on this wrapper if the
        worker has not been initialized yet.

        Any exception is logged and then re-raised unchanged.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare raise re-raises the active exception with its original
            # traceback.
            raise
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485


def extract_previous_hidden_states(
    data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Collect ``previous_hidden_states`` from ``data`` as model kwargs.

    The result is a new dict that can be splatted directly into a following
    ``execute_model`` call; it is empty when ``data`` carries no previous
    hidden states. Used by draft models such as EAGLE.

    Non-driver workers pass the received broadcast dict; the driver worker
    passes the ExecuteModelRequest itself, whose ``previous_hidden_states``
    object is unwrapped via its ``hidden_states`` attribute.
    """
    if isinstance(data, dict):
        if "previous_hidden_states" not in data:
            return {}
        return {"previous_hidden_states": data["previous_hidden_states"]}

    previous = data.previous_hidden_states
    if previous is None:
        return {}
    return {"previous_hidden_states": previous.hidden_states}