# worker_base.py
1
import dataclasses
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import cloudpickle
import torch

from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
                        resolve_obj_by_qualname, run_method,
                        update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)
22
23

logger = init_logger(__name__)
24
25
26
27


class WorkerBase(ABC):
    """Hardware-agnostic worker interface.

    Keeps each hardware backend's worker implementation separate from the
    others, and also abstracts control-plane communication (for example,
    sending request metadata to other workers).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        cfg = vllm_config
        self.vllm_config = cfg
        # Mirror every sub-config as its own attribute so subclasses can use
        # e.g. `self.model_config` directly.
        self.model_config = cfg.model_config
        self.cache_config = cfg.cache_config
        self.lora_config = cfg.lora_config
        self.load_config = cfg.load_config
        self.parallel_config = cfg.parallel_config
        self.scheduler_config = cfg.scheduler_config
        self.device_config = cfg.device_config
        self.speculative_config = cfg.speculative_config
        self.prompt_adapter_config = cfg.prompt_adapter_config
        self.observability_config = cfg.observability_config
        self.kv_transfer_config = cfg.kv_transfer_config
        self.compilation_config = cfg.compilation_config

        # Function-scope import. NOTE(review): presumably deferred on purpose
        # (import cycle or late platform detection) -- confirm before
        # hoisting it to module level.
        from vllm.platforms import current_platform
        self.current_platform = current_platform

    @abstractmethod
    def init_device(self) -> None:
        """Set up on-device state, such as loading the model or making other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many KV-cache blocks fit on the GPU and in swappable
        CPU memory.

        Implementations are free to run profiling or other heuristics to size
        the caches.

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. GPU blocks are "active" on
            the device and can be appended to; CPU blocks are "swapped" into
            CPU memory and cannot be appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate the KV cache with the given sizes, measured in blocks."""
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Run the model-execution loop of a parallel worker.

        Repeatedly calls ``execute_model(execute_model_req=None)`` inside the
        platform's inference mode and returns as soon as a call yields
        ``None``. The loop is stopped by executing a driver worker with an
        empty output; see `stop_remote_worker_execution_loop` for details.
        """
        with self.current_platform.inference_mode():
            output = self.execute_model(execute_model_req=None)
            while output is not None:
                output = self.execute_model(execute_model_req=None)

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Execute the model for ``execute_model_req``.

        Not declared abstract, but this base version always raises
        ``NotImplementedError``, so concrete workers must override it.
        ``start_worker_execution_loop`` treats a ``None`` result as the signal
        to stop looping.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Size, in bytes, of a single cache block. Used in speculative
        decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """LoRA hook: add the adapter described by ``lora_request``.

        The meaning of the returned bool is defined by implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """LoRA hook: remove the adapter identified by ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        """LoRA hook: pin the adapter identified by ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """LoRA hook: return a set of adapter ids (presumably the adapters
        currently held by this worker -- defined by implementations).
        """
        raise NotImplementedError


123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
class DelegateWorkerBase(WorkerBase):
    """
    A class that delegates all methods to another WorkerBase instance. This is
    useful for creating a WorkerBase that wraps another WorkerBase instance,
    e.g. speculative decoding.

    Attribute lookups that fail on the delegate itself fall through to the
    wrapped worker via `__getattr__`.
    """
    worker: WorkerBase

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        """Build the wrapped worker from the same constructor arguments.

        The wrapped worker's class comes from
        ``vllm_config.parallel_config.worker_cls``. That value may be a
        qualified name (``str``) or a cloudpickled class (``bytes``).
        ``vllm_config`` is read from the keyword arguments. Failing that, it
        is read from the first positional argument, which is where
        ``WorkerBase.__init__`` takes it.

        ``WorkerBase.__init__`` is not called on the delegate. Attributes such
        as ``vllm_config`` are instead resolved on the wrapped worker through
        ``__getattr__``.

        Raises:
            TypeError: if no ``vllm_config`` argument was supplied.
        """
        vllm_config: Optional[VllmConfig] = kwargs.get(
            "vllm_config", args[0] if args else None)
        if vllm_config is None:
            raise TypeError(
                f"{type(self).__name__} requires a `vllm_config` argument")

        worker_cls = vllm_config.parallel_config.worker_cls
        if isinstance(worker_cls, bytes):
            # Same cloudpickled-class encoding that
            # WorkerWrapperBase.init_worker accepts.
            cls = cloudpickle.loads(worker_cls)
        else:
            cls = resolve_obj_by_qualname(worker_cls)
        self.worker = cls(*args, **kwargs)

    def init_device(self) -> None:
        self.worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return self.worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        return self.worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        return self.worker.get_cache_block_size_bytes()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.worker.list_loras()

    def __getattr__(self, attr):
        # `__getattr__` only runs after normal lookup has failed. If `worker`
        # itself is missing (e.g. the instance was created without running
        # `__init__`, as copy/unpickle do), evaluating `self.worker` below
        # would re-enter `__getattr__` forever. Raise AttributeError instead
        # so `hasattr`/`getattr(..., default)` behave normally.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)


175
176
177
178
179
180
181
182
183
184
185
class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.

    Every LoRA method raises ``ValueError`` naming the concrete worker type.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        # Raise (not return) the error, like the other LoRA methods: a
        # returned exception object is truthy and would look like success.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
192
193


194
195
196
197
198
199
200
201
202
203
@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Per-worker local inputs, which may include device-specific data.

    Every field is expected to be broadcastable to the other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """Build a WorkerInput from a broadcast dict, consuming its fields.

        Every field except ``virtual_engine`` is popped from ``tensor_dict``.
        ``virtual_engine`` is only read and stays in the dict. A missing key
        raises ``KeyError``, and keys are consumed in declaration order up to
        that point.
        """
        num_seq_groups = tensor_dict.pop("num_seq_groups")
        swap_in = tensor_dict.pop("blocks_to_swap_in")
        swap_out = tensor_dict.pop("blocks_to_swap_out")
        copies = tensor_dict.pop("blocks_to_copy")
        # Read rather than popped, so it remains available to later consumers
        # of the same dict (presumably the model runner's input
        # reconstruction -- confirm).
        virtual_engine = tensor_dict["virtual_engine"]
        num_steps = tensor_dict.pop("num_steps")
        return cls(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=swap_in,
            blocks_to_swap_out=swap_out,
            blocks_to_copy=copies,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the broadcastable fields as a new dict keyed by field name.

        Values are the field objects themselves (no copies), including any
        ``None`` fields.
        """
        field_names = (
            "num_seq_groups",
            "blocks_to_swap_in",
            "blocks_to_swap_out",
            "blocks_to_copy",
            "virtual_engine",
            "num_steps",
        )
        return {name: getattr(self, name) for name in field_names}


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If a WorkerBase subclass only supports
        single-worker execution, then this property should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """Receive the driver's broadcast and rebuild this step's inputs.

        Only valid on non-driver workers with metadata broadcast enabled.

        Returns:
            ``(model_input, worker_input, kwargs)``, or ``None`` when the
            driver broadcast an empty dict (the stop signal sent by
            ``prepare_input``).
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        # WorkerInput consumes its own entries first; the remaining entries
        # are used to rebuild the model input and the extra kwargs.
        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        model_input = (
            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                broadcast_data))

        kwargs = extract_previous_hidden_states(broadcast_data)

        return model_input, worker_input, kwargs

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """Prepare the driver's inputs and broadcast them to other workers.

        Only valid on the driver worker. The broadcast happens only when
        ``do_metadata_broadcast`` is true.

        Returns:
            ``(model_input, worker_input, kwargs)`` for the local step.
        """
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        # Attached after the broadcast above, so the callback is only set on
        # the driver's local model input and is never broadcast.
        if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.

        Returns ``None`` when there is nothing to execute. On the driver this
        happens when ``execute_model_req`` is None; in that case the driver
        also broadcasts the stop signal. On other workers it happens when the
        stop signal is received.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.

        Returns:
            ``None`` if there is no input (loop stop signal); ``[]`` if the
            step has zero sequence groups; ``[None]`` on non-last pipeline
            ranks, whose output is sent to the next PP rank instead; otherwise
            the model runner's output.
        """
        # Timing starts before input preparation, so the reported execute
        # time includes prepare_input and execute_worker.
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            # Time reported by earlier PP ranks, accumulated across stages.
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        output = self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            **kwargs,
        )

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model.

        Returns ``[]`` when the request has zero sequence groups; otherwise
        the model runner's output.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        # Pass the virtual engine and finished request ids, exactly as
        # `_get_driver_input_and_broadcast` does. Without them the model
        # input defaults to virtual engine 0 while the KV cache below is
        # selected by `worker_input.virtual_engine`.
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )
462

463

464
465
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.

    The wrapper is instantiated first and only stores the config. The worker
    class itself is resolved later from
    `vllm_config.parallel_config.worker_cls`. Then
    `update_environment_variables` is called, and the real initialization
    happens in `init_worker`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.vllm_config = vllm_config
        self.worker: Optional[WorkerBase] = None
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        Ranks absent from the mapping are left unchanged.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                str]]) -> None:
        """Apply this rank's entry of ``envs_list`` to the process env."""
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        ``all_kwargs`` holds one kwargs dict per rank; this wrapper uses the
        entry at ``self.rpc_rank``. ``worker_cls`` may be a qualified name
        (str) or a cloudpickled class (bytes).
        """
        kwargs = all_kwargs[self.rpc_rank]
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm import configure_as_vllm_process
        configure_as_vllm_process()

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        self.worker = worker_class(**kwargs)
        assert self.worker is not None

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Run ``method`` on the worker, or on this wrapper before
        ``init_worker`` has created the worker.

        Any exception is logged and then re-raised unchanged.
        """
        try:
            target = self if self.worker is None else self.worker
            return run_method(target, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare `raise` re-raises the active exception with its original
            # traceback.
            raise

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. If `worker` itself is
        # missing (e.g. an instance created without running `__init__`, as
        # unpickling does), `self.worker` below would re-enter `__getattr__`
        # forever; raise AttributeError instead.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)

560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578

def extract_previous_hidden_states(
    data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """Return ``previous_hidden_states`` from ``data`` as an execute_model
    kwarg.

    The result is a new dict that can be splatted directly into a following
    ``execute_model`` call. It is empty when ``data`` carries no previous
    hidden states. This is used in draft models like EAGLE.

    Args:
        data: the broadcast dict (on non-driver workers) or the
            ExecuteModelRequest (on the driver worker).
    """
    if isinstance(data, dict):
        # Non-driver path: the value is copied through as-is.
        if "previous_hidden_states" not in data:
            return {}
        return {"previous_hidden_states": data["previous_hidden_states"]}

    # Driver path: unwrap the `.hidden_states` attribute of the request's
    # previous_hidden_states object.
    previous = data.previous_hidden_states
    if previous is None:
        return {}
    return {"previous_hidden_states": previous.hidden_states}