"tests/vscode:/vscode.git/clone" did not exist on "84275504885ae5d4b3c63209f711706c8b758882"
worker_base.py 20 KB
Newer Older
1
import dataclasses
2
3
import importlib
import os
4
import time
5
from abc import ABC, abstractmethod
6
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
7

8
9
import torch

10
from vllm.config import ObservabilityConfig
11
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
from vllm.model_executor.layers.sampler import SamplerOutput
15
from vllm.platforms import current_platform
16
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
17
from vllm.utils import (enable_trace_function_call_for_thread,
18
                        update_environment_variables)
19
from vllm.worker.cache_engine import CacheEngine
20
21
22
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)
23
24

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)
25
26
27
28


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.
    """

    # Most recent model input recorded on this worker; None until a subclass
    # stores one.
    model_input: Optional[ModelRunnerInputBase] = None

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    @current_platform.inference_mode()
    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        Repeatedly calls `execute_model` with no request and returns as soon
        as it yields None. You can stop the loop by executing a driver worker
        with an empty output. See `stop_remote_worker_execution_loop` for more
        details.
        """
        while True:
            output = self.execute_model(execute_model_req=None)
            if output is None:
                return None

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Run the model for ``execute_model_req``.

        Returning None signals `start_worker_execution_loop` to exit.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA adapter described by ``lora_request``."""
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id``."""
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return the ids of the LoRA adapters known to this worker."""
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Implementations without a KV cache return None.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        """Cache engines owned by this worker, or None if it has none."""
        raise NotImplementedError
122
123
124
125
126
127
128
129
130
131
132
133
134


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.

    It also reports no KV cache and no cache engines: both properties return
    None.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        # Raise (like the sibling LoRA methods) instead of returning the
        # exception object, which a caller could mistake for a truthy success.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return None

    @property
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        return None
149
150


151
152
153
154
155
156
157
158
159
160
@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Local inputs to each worker. May contain device-specific data. These
    fields should be broadcastable to other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1
    # Optional KV-cache slot mapping, generated by the draft model, for
    # entries that are pending to be moved.
    kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """
        Pop fields from the given tensor_dict and populate a new instance of
        WorkerInput.

        ``virtual_engine`` is read but deliberately left in ``tensor_dict``;
        presumably the model-input deserializer that receives the same dict
        afterwards needs it too.
        """
        return cls(
            num_seq_groups=tensor_dict.pop("num_seq_groups"),
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            virtual_engine=tensor_dict["virtual_engine"],
            num_steps=tensor_dict.pop("num_steps"),
            kvcache_slot_to_be_moved=tensor_dict.pop(
                "kvcache_slot_to_be_moved"),
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields.

        The keys match exactly what `from_broadcasted_tensor_dict` reads.
        """
        tensor_dict = {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
            "num_steps": self.num_steps,
            "kvcache_slot_to_be_moved": self.kvcache_slot_to_be_moved,
        }

        return tensor_dict


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    @staticmethod
    def _set_tree_style_args(execute_model_req: ExecuteModelRequest) -> None:
        """Attach per-sequence-group tree attention masks and position ids.

        No-op when the request carries no ``tree_attn_masks``. Otherwise entry
        ``i`` of ``tree_attn_masks`` / ``tree_position_ids`` is handed to the
        ``i``-th entry of ``seq_group_metadata_list``.
        """
        if execute_model_req.tree_attn_masks is None:
            return
        for i, seq_group_metadata in enumerate(
                execute_model_req.seq_group_metadata_list):
            seq_group_metadata.set_tree_style_args(
                tree_attn_masks=execute_model_req.tree_attn_masks[i],
                tree_position_ids=execute_model_req.tree_position_ids[i])

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """ Get the worker input from the broadcasted tensor dict.

        Returns None when the driver broadcast an empty dict, which is the
        signal to stop the execution loop.
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        # WorkerInput pops most of its fields out of broadcast_data; the
        # model runner then builds its input from the same dict.
        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        model_input = (
            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                broadcast_data))

        kwargs = extract_previous_hidden_states(broadcast_data)

        return model_input, worker_input, kwargs

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """ Get the driver input and broadcast it to other workers.  """
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)

        # Tree-style args must be on the sequence groups before the model
        # runner builds its input from them.
        self._set_tree_style_args(execute_model_req)

        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        if execute_model_req.async_callback:
            # Attached after the broadcast, so the callback is not sent to
            # the other workers.
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.

        Returns None when there is nothing to execute: on the driver when no
        request is given, on other workers when the driver broadcast an empty
        input.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.

        Returns None when `prepare_input` finds nothing to run, ``[]`` when the
        batch has no sequence groups, ``[None]`` on non-last pipeline ranks
        (after sending their intermediate tensors on), and the model runner's
        output otherwise.
        """
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps

        # Keep the latest model input on the worker.
        self.model_input = model_input

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            # Non-first pipeline ranks receive intermediate tensors from the
            # pipeline-parallel group.
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        output = self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            **kwargs,
        )

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model.

        Input preparation mirrors the driver path: tree-style args are
        attached and the request's virtual engine and finished request ids
        are forwarded to the model runner.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        self._set_tree_style_args(execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )
432

433

434
435
436
437
438
439
class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. `update_environment_variables` can then be called to
    adjust the process environment, and the real initialization happens in
    `init_worker`.

    If worker_class_fn is specified, it will be executed to get the worker
    class.
    Otherwise, the worker class will be obtained by dynamically importing it
    using worker_module_name and worker_class_name.
    """

    def __init__(
        self,
        worker_module_name: str,
        worker_class_name: str,
        trust_remote_code: bool = False,
        worker_class_fn: Optional[Callable[[],
                                           Type["WorkerBase"]]] = None
    ) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker_class_fn = worker_class_fn
        # Created lazily by `init_worker`.
        self.worker: Optional[WorkerBase] = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        """Apply ``envs`` to the process environment."""
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, *args, **kwargs):
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()

        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if self.worker_class_fn:
            worker_class = self.worker_class_fn()
        else:
            mod = importlib.import_module(self.worker_module_name)
            worker_class = getattr(mod, self.worker_class_name)

        self.worker = worker_class(*args, **kwargs)
        assert self.worker is not None

    def execute_method(self, method, *args, **kwargs):
        """Call ``method`` by name on the worker (or on this wrapper if the
        worker has not been created yet) and return its result.

        Any exception is logged and then re-raised unchanged.
        """
        try:
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare `raise` re-raises the active exception without adding a
            # re-raise frame to its traceback.
            raise
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527


def extract_previous_hidden_states(
        data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Collect ``previous_hidden_states`` from ``data`` as extra kwargs.

    The returned dict is either empty or holds the single key
    ``"previous_hidden_states"``, so it can be splatted directly into a
    following ``execute_model`` call. Draft models such as EAGLE rely on this.

    On non-driver workers ``data`` is the broadcast dict; on the driver worker
    it is an ExecuteModelRequest.
    """
    key = "previous_hidden_states"
    if not isinstance(data, dict):
        # Driver side: unwrap the attached container, if any.
        carried = data.previous_hidden_states
        return {} if carried is None else {key: carried.hidden_states}
    # Non-driver side: forward the broadcast entry as-is when present.
    return {key: data[key]} if key in data else {}