worker_base.py 25.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import os
6
import time
7
from abc import abstractmethod
8
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
9

10
import cloudpickle
11
import torch
12
import torch.nn as nn
13

14
15
from vllm.config import (ObservabilityConfig, VllmConfig,
                         set_current_vllm_config)
16
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
20
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
21
from vllm.utils import (enable_trace_function_call_for_thread,
22
                        resolve_obj_by_qualname, run_method,
23
24
                        update_environment_variables,
                        warn_for_unimplemented_methods)
25
26
27
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)
28
29

# Module-level logger, named after this module via `init_logger(__name__)`.
logger = init_logger(__name__)
30
31


32
33
@warn_for_unimplemented_methods
class WorkerBase:
    """Hardware-agnostic worker interface.

    Each hardware backend subclasses this so that device-specific code stays
    separate from the rest of vLLM. The interface also abstracts the control
    plane, e.g. how request metadata reaches the other workers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        # Keep the complete config, and expose every sub-config as its own
        # attribute so subclasses can read e.g. `self.model_config` directly.
        cfg = vllm_config
        self.vllm_config = cfg
        self.model_config = cfg.model_config
        self.cache_config = cfg.cache_config
        self.lora_config = cfg.lora_config
        self.load_config = cfg.load_config
        self.parallel_config = cfg.parallel_config
        self.scheduler_config = cfg.scheduler_config
        self.device_config = cfg.device_config
        self.speculative_config = cfg.speculative_config
        self.prompt_adapter_config = cfg.prompt_adapter_config
        self.observability_config = cfg.observability_config
        self.kv_transfer_config = cfg.kv_transfer_config
        self.compilation_config = cfg.compilation_config

        # The platform object is imported when a worker is constructed rather
        # than when this module is imported.
        from vllm.platforms import current_platform
        self.current_platform = current_platform

    def init_device(self) -> None:
        """Set up device state, such as loading the model or making other
        on-device memory allocations.
        """
        raise NotImplementedError

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given number of GPU and CPU
        blocks.
        """
        raise NotImplementedError

    def get_model(self) -> nn.Module:
        """Return the model owned by this worker."""
        raise NotImplementedError

    def load_model(self) -> None:
        """Load the model onto the target device."""
        raise NotImplementedError

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Run the model for one request. A None result ends
        `start_worker_execution_loop`.
        """
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Keep executing the model on a parallel (non-driver) worker.

        Each iteration calls `execute_model` with no request and the loop ends
        as soon as it yields None, which happens when the driver worker runs
        with empty input. See `stop_remote_worker_execution_loop` for details.
        """
        with self.current_platform.inference_mode():
            while self.execute_model(execute_model_req=None) is not None:
                pass

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Work out how many KV cache blocks fit on the device (GPU) and in
        swappable CPU memory.

        Implementations may profile or apply other heuristics to size the
        caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks]: num_gpu_blocks are
        "active" blocks on the device that can be appended to, while
        num_cpu_blocks are "swapped" blocks in CPU memory that cannot be
        appended to.
        """
        raise NotImplementedError

    def get_cache_block_size_bytes(self) -> int:
        """Size in bytes of one cache block; used by speculative decoding."""
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register the LoRA adapter described by `lora_request`."""
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        """Unregister the LoRA adapter with id `lora_id`."""
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id `lora_id`."""
        raise NotImplementedError

    def list_loras(self) -> Set[int]:
        """Return the ids of the registered LoRA adapters."""
        raise NotImplementedError

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model configuration."""
        return self.model_config.get_vocab_size()

133

134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class DelegateWorkerBase(WorkerBase):
    """
    A class that delegates all methods to another WorkerBase instance. This is
    useful for creating a WorkerBase that wraps another WorkerBase instance,
    e.g. speculative decoding.

    The wrapped worker's class is resolved from
    ``vllm_config.parallel_config.worker_cls`` and constructed with the same
    arguments this class received. Attributes not found on this object are
    looked up on the wrapped worker through ``__getattr__``.
    """
    worker: WorkerBase

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        # `WorkerBase.__init__` is not called here; attribute reads such as
        # `self.model_config` fall through to the wrapped worker via
        # `__getattr__`.
        vllm_config: Optional[VllmConfig] = kwargs.get("vllm_config")
        assert vllm_config is not None, (
            "vllm_config must be passed as a keyword argument to "
            f"{type(self).__name__}")
        cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
        self.worker = cls(*args, **kwargs)

    def init_device(self) -> None:
        self.worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return self.worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def load_model(self) -> None:
        """Load model onto target device."""
        self.worker.load_model()

    def get_model(self) -> nn.Module:
        return self.worker.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        return self.worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        return self.worker.get_cache_block_size_bytes()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.worker.list_loras()

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. If `worker` itself
        # is missing (e.g. the object was created without running `__init__`,
        # as `copy.copy` and unpickling do), `self.worker` below would call
        # back into this method forever and end in RecursionError. Raise a
        # plain AttributeError instead so `hasattr`/`getattr(..., default)`
        # behave normally.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no wrapped worker yet")
        return getattr(self.worker, attr)


193
class LoRANotSupportedWorkerBase(WorkerBase):
    """Partial WorkerBase for backends that have no LoRA support.

    Every LoRA management method raises ValueError naming the concrete worker
    type.
    """

    def _lora_unsupported_error(self) -> ValueError:
        # Build (but do not raise) the shared error so that each public method
        # keeps an explicit `raise` statement.
        return ValueError(f"{type(self)} does not support LoRA")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise self._lora_unsupported_error()

    def remove_lora(self, lora_id: int) -> bool:
        raise self._lora_unsupported_error()

    def pin_lora(self, lora_id: int) -> bool:
        raise self._lora_unsupported_error()

    def list_loras(self) -> Set[int]:
        raise self._lora_unsupported_error()
209
210


211
212
213
214
215
216
217
218
219
220
@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Per-worker inputs, which may hold device-specific data.

    Every field must be broadcastable from the driver to the other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """Build a WorkerInput from a received broadcast dict.

        The consumed keys are removed from ``tensor_dict`` in place, except
        ``virtual_engine``, which is only read. A missing key raises KeyError
        (keys popped before it stay popped).
        """
        popped = {
            key: tensor_dict.pop(key)
            for key in ("num_seq_groups", "blocks_to_swap_in",
                        "blocks_to_swap_out", "blocks_to_copy")
        }
        # `virtual_engine` is read but left in the dict: presumably the
        # remaining dict is handed to the model runner afterwards, which
        # needs it as well — confirm before changing this to a pop.
        virtual_engine = tensor_dict["virtual_engine"]
        num_steps = tensor_dict.pop("num_steps")
        return cls(**popped, virtual_engine=virtual_engine, num_steps=num_steps)

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """Return the broadcastable fields as a fresh dict keyed by name."""
        keys = ("num_seq_groups", "blocks_to_swap_in", "blocks_to_swap_out",
                "blocks_to_copy", "virtual_engine", "num_steps")
        return {key: getattr(self, key) for key in keys}


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """Receive the driver's broadcast and rebuild this worker's inputs.

        Only valid on non-driver workers with metadata broadcast enabled.
        Returns None when the driver broadcasts an empty dict (the stop
        signal sent by `prepare_input`); otherwise returns
        ``(model_input, worker_input, extra_kwargs)``.
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        # WorkerInput pops its own keys out of `broadcast_data`; the model
        # runner then receives the same, partially consumed, dict.
        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        model_input = (
            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                broadcast_data))

        kwargs = extract_previous_hidden_states(broadcast_data)

        return model_input, worker_input, kwargs

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """Build inputs on the driver and broadcast them if enabled.

        The broadcast payload merges the worker input, the model input and the
        extra kwargs (e.g. previous hidden states) into a single dict, which
        `_get_worker_input_from_broadcast` unpacks on the other workers.
        """
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        # The async callback is attached only after broadcasting, so it stays
        # local to the driver.
        if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.

        On the driver, a None request returns None (after broadcasting the
        stop signal when broadcasting is enabled); otherwise the inputs are
        built and broadcast. On other workers, the inputs are received from
        the driver's broadcast.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided.

        Returns None when there is nothing to run (stop signal), ``[]`` for an
        empty batch, ``[None]`` on non-last pipeline-parallel ranks (their
        output is sent to the next rank instead), and the model runner's output
        otherwise.
        """
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps
        if execute_model_req is not None and execute_model_req.spec_step_idx:
            kwargs["spec_step_idx"] = execute_model_req.spec_step_idx

        # `execute_worker` runs even for an empty batch; only the model
        # forward below is skipped.
        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            # Every PP rank except the first receives its input tensors
            # through the PP group.
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        # Read the `kv_cache` property once, then select this virtual
        # engine's cache.
        kv_caches = self.kv_cache
        output = self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=kv_caches[worker_input.virtual_engine]
            if kv_caches is not None else None,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            **kwargs,
        )

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                # Accumulate the time spent on earlier ranks with this one.
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        # Pass the same arguments as the driver path
        # (`_get_driver_input_and_broadcast`). Previously the model input
        # always used the default virtual engine, while the KV cache below is
        # selected by `worker_input.virtual_engine`, and the finished request
        # ids were dropped.
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        kv_caches = self.kv_cache
        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=kv_caches[worker_input.virtual_engine]
            if kv_caches is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )
484

485

486
487
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.

    The wrapper is constructed first. `update_environment_variables` then
    applies this rank's environment, and the real worker is only created in
    `init_worker`, which resolves the worker class from
    `vllm_config.parallel_config.worker_cls`. Attributes not defined on the
    wrapper are forwarded to the worker via `__getattr__`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        self.vllm_config: Optional[VllmConfig] = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.
        Ranks absent from the mapping are left unchanged.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                str]]) -> None:
        """Apply the environment variables at index `rpc_rank` of
        `envs_list` to this process."""
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Create the worker for this rank.

        Picks this rank's constructor kwargs from `all_kwargs`, calls
        `enable_trace_function_call_for_thread`, loads general plugins,
        resolves the worker class (optionally mixing in
        `worker_extension_cls`) and instantiates it with those kwargs while
        the vLLM config is set as current.
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config", None)
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            logger.warning(
                "passing worker_cls as a class object is strongly deprecated,"
                " as the serialization of class objects can be tricky and"
                " error-prone. To be safe, please keep the class in a separate"
                " module and pass the qualified name of the class as a string."
            )
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        if self.vllm_config.parallel_config.worker_extension_cls:
            worker_extension_cls = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_extension_cls)
            extended_calls = []
            if worker_extension_cls not in worker_class.__bases__:
                # check any conflicts between worker and worker_extension_cls
                for attr in dir(worker_extension_cls):
                    if attr.startswith("__"):
                        continue
                    assert not hasattr(worker_class, attr), (
                        f"Worker class {worker_class} already has an attribute"
                        f" {attr}, which conflicts with the worker"
                        f" extension class {worker_extension_cls}.")
                    if callable(getattr(worker_extension_cls, attr)):
                        extended_calls.append(attr)
                # dynamically inherit the worker extension class; note this
                # mutates `worker_class` itself, not just this instance.
                worker_class.__bases__ = worker_class.__bases__ + (
                    worker_extension_cls, )
                logger.info(
                    "Injected %s into %s for extended collective_rpc calls %s",
                    worker_extension_cls, worker_class, extended_calls)
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
            assert self.worker is not None

    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
        """Forward this rank's entry of `kv_cache_configs` to the worker."""
        kv_cache_config = kv_cache_configs[self.rpc_rank]
        with set_current_vllm_config(self.vllm_config):
            self.worker.initialize_from_config(kv_cache_config)  # type: ignore

    def init_device(self):
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        try:
            # method resolution order:
            # if a method is defined in this class, it will be called directly.
            # otherwise, since we define `__getattr__` and redirect attribute
            # query to `self.worker`, the method will be called on the worker.
            return run_method(self, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            # Bare `raise` re-raises the active exception with its original
            # traceback.
            raise

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails. If `worker` itself
        # is missing (e.g. the object was created without running `__init__`,
        # as `copy.copy` and unpickling do), `self.worker` below would call
        # back into this method forever and end in RecursionError.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)

628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646

def extract_previous_hidden_states(
    data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Collect ``previous_hidden_states`` from ``data`` as extra kwargs.

    The returned dict is either empty or holds a single
    ``previous_hidden_states`` entry, so it can be passed directly as
    additional kwargs to a following ``execute_model`` call. Draft models
    such as EAGLE use this.

    Non-driver workers pass the received broadcast dict. The driver worker
    passes its ``ExecuteModelRequest``, whose ``previous_hidden_states`` is
    unwrapped through its ``hidden_states`` attribute.
    """
    if isinstance(data, dict):
        # Broadcast payload: forward the entry unchanged when present.
        if "previous_hidden_states" not in data:
            return {}
        return {"previous_hidden_states": data["previous_hidden_states"]}

    # ExecuteModelRequest: unwrap the hidden-states holder, if any.
    holder = data.previous_hidden_states
    if holder is None:
        return {}
    return {"previous_hidden_states": holder.hidden_states}