worker_base.py 24.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import dataclasses
4
import os
5
import time
6
from abc import ABC, abstractmethod
7
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
8

9
import cloudpickle
10
import torch
11
import torch.nn as nn
12

13
14
from vllm.config import (ObservabilityConfig, VllmConfig,
                         set_current_vllm_config)
15
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
16
from vllm.logger import init_logger
17
from vllm.lora.request import LoRARequest
18
19
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
20
from vllm.utils import (enable_trace_function_call_for_thread,
21
22
                        resolve_obj_by_qualname, run_method,
                        update_environment_variables)
zhuwenwen's avatar
zhuwenwen committed
23
from vllm.worker.cache_engine import CacheEngine
24
25
26
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)
27
28

# Module-level logger, obtained from init_logger for this module's __name__.
logger = init_logger(__name__)
29
30
31
32


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations for
    different hardware. Also abstracts control plane communication, e.g., to
    communicate request metadata to other workers.
    """

    # Defaults to None; a worker may store the model input of the step it is
    # currently executing here.
    model_input: Optional[ModelRunnerInputBase] = None
    # True only when the VLLM_TREE_DECODING environment variable is exactly
    # '1'. The environment is read once, when this class body is executed.
    tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        """Keep `vllm_config` and expose each of its sub-configs as an
        attribute of the worker.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config
        self.kv_transfer_config = vllm_config.kv_transfer_config
        self.compilation_config = vllm_config.compilation_config
        # Imported inside the constructor, not at module import time.
        from vllm.platforms import current_platform
        self.current_platform = current_platform

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other on-device
        memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:
        """Execute model loop in parallel worker.

        You can stop the loop by executing a driver worker with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        with self.current_platform.inference_mode():
            while True:
                # Called without a request; a `None` result is the signal to
                # leave the loop.
                output = self.execute_model(execute_model_req=None)
                if output is None:
                    return None

    @abstractmethod
    def get_model(self) -> nn.Module:
        """Return the model as an `nn.Module`."""
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Execute the given request.

        `start_worker_execution_loop` calls this with no request and stops as
        soon as it returns `None`.
        """
        raise NotImplementedError

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the LoRA described by `lora_request`; the meaning of the
        returned bool is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA identified by `lora_id`; the meaning of the
        returned bool is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA identified by `lora_id`; the meaning of the returned
        bool is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return a set of integer LoRA ids."""
        raise NotImplementedError

    @property
    @abstractmethod
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        """Return this worker's list of `CacheEngine` objects, or `None`."""
        raise NotImplementedError
140
141


142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
class DelegateWorkerBase(WorkerBase):
    """
    A class that delegates all methods to another WorkerBase instance. This is
    useful for creating a WorkerBase that wraps another WorkerBase instance,
    e.g. speculative decoding.
    """
    worker: WorkerBase

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        """Build the wrapped worker.

        `vllm_config` must be supplied as a keyword argument: its
        `parallel_config.worker_cls` is resolved to a class, which is then
        instantiated with exactly the arguments given here.
        """
        vllm_config: VllmConfig = kwargs.get("vllm_config")
        cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
        self.worker = cls(*args, **kwargs)

    def init_device(self) -> None:
        """Delegate to the wrapped worker."""
        self.worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Delegate to the wrapped worker."""
        return self.worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Delegate to the wrapped worker."""
        self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def get_model(self) -> nn.Module:
        """Delegate to the wrapped worker."""
        return self.worker.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        """Delegate to the wrapped worker."""
        return self.worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        """Delegate to the wrapped worker."""
        return self.worker.get_cache_block_size_bytes()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Delegate to the wrapped worker."""
        return self.worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Delegate to the wrapped worker."""
        return self.worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the wrapped worker."""
        return self.worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Delegate to the wrapped worker."""
        return self.worker.list_loras()

    @property
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        """Delegate to the wrapped worker.

        `WorkerBase` declares `cache_engines` abstract, so a concrete
        definition is required here for this class to be instantiable.
        """
        return self.worker.cache_engines

    def __getattr__(self, attr):
        """Forward any attribute not found on the delegate to the wrapped
        worker.
        """
        # `__getattr__` only runs after normal lookup failed. If the missing
        # name is `worker` itself (instance not initialised yet, e.g. while
        # being copied or unpickled), forwarding would recurse forever.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)


197
198
199
200
201
202
203
204
205
206
207
class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Always raises `ValueError`: LoRA is not supported."""
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        """Always raises `ValueError`: LoRA is not supported."""
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        """Always raises `ValueError`: LoRA is not supported."""
        # Raise (not return) the error, like the other LoRA methods here; a
        # returned exception object would be truthy and look like success.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        """Always raises `ValueError`: LoRA is not supported."""
        raise ValueError(f"{type(self)} does not support LoRA")

    @property
    def cache_engines(self) -> Optional[List[CacheEngine]]:
        """Return `None`."""
        return None
218
219


220
221
222
223
224
225
226
227
228
229
@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Local inputs to each worker. May contain device-specific data. These
    fields should be broadcastable to other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1

    # Optional slot mapping of the KV cache entries, generated from the draft
    # model, that are pending to be moved.
    kvcache_slot_to_be_moved: Optional[torch.Tensor] = None

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """
        Pop fields from the given tensor_dict and populate a new instance of
        WorkerInput.

        `tensor_dict` is modified in place: every field except
        `virtual_engine` is removed from it. A missing key raises `KeyError`.
        """
        return cls(
            num_seq_groups=tensor_dict.pop("num_seq_groups"),
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            # Read without popping, so the key stays available to whoever
            # consumes the remainder of `tensor_dict`.
            virtual_engine=tensor_dict["virtual_engine"],
            num_steps=tensor_dict.pop("num_steps"),
            kvcache_slot_to_be_moved=tensor_dict.pop(
                "kvcache_slot_to_be_moved"),
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields.

        Returns a new dict with one entry per dataclass field, keyed by the
        field name; unset optional fields appear with the value `None`.
        """
        tensor_dict = {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
            "num_steps": self.num_steps,
            "kvcache_slot_to_be_moved": self.kvcache_slot_to_be_moved,
        }

        return tensor_dict


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """ Get the worker input from the broadcasted tensor dict.

        Returns `None` when the received dict is empty; otherwise a
        `(model_input, worker_input, kwargs)` tuple.
        """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        # The WorkerInput fields are popped from `broadcast_data` first; the
        # model runner then builds its input from the same dict.
        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        model_input = (
            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                broadcast_data))

        kwargs = extract_previous_hidden_states(broadcast_data)

        return model_input, worker_input, kwargs

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """ Get the driver input and broadcast it to other workers.

        Returns a `(model_input, worker_input, kwargs)` tuple. The broadcast
        only happens when `do_metadata_broadcast` is true.
        """
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)

        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        # Tree decoding: when the request carries both tree position ids and
        # tree attention masks, and the model input exposes the matching
        # fields, override them with the request's (contiguous) tensors.
        if (self.tree_decoding
                and execute_model_req.tree_position_ids is not None
                and execute_model_req.tree_attn_masks is not None):
            if (hasattr(model_input, "input_positions")
                    and hasattr(model_input, "attn_metadata")
                    and hasattr(model_input.attn_metadata,
                                "tree_attention_masks_tensor")):
                # The attention metadata object is updated in place and then
                # passed back into the replaced model input.
                attn_metadata = model_input.attn_metadata
                attn_metadata.tree_attention_masks_tensor = (
                    execute_model_req.tree_attn_masks.contiguous())
                model_input = dataclasses.replace(
                    model_input,
                    input_positions=execute_model_req.tree_position_ids.
                    contiguous(),
                    attn_metadata=attn_metadata)

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        # The callback is attached after the broadcast, so it is only part of
        # the driver's own model input.
        if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.

        Returns `None` when there is nothing to execute: the driver was given
        no request, or a non-driver worker received an empty broadcast.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        # Taken before input preparation, so the reported execute time also
        # covers prepare_input, execute_worker and the pipeline receive below.
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps

        self.model_input = model_input

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                # Execute time accumulated so far, as carried in the received
                # tensors (0 if the entry is absent).
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        output = self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            **kwargs,
        )

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            # Non-last pipeline ranks have no sampler output to return.
            return [None]
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )
510

511

512
513
class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible
    for lazily initializing the worker and handling the worker's lifecycle.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. Then we call `update_environment_variables`, and the
    real initialization happens in `init_worker`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        rpc_rank: int = 0,
    ) -> None:
        """
        Initialize the worker wrapper with the given vllm_config and rpc_rank.
        Note: rpc_rank is the rank of the worker in the executor. In most cases,
        it is also the rank of the worker in the distributed group. However,
        when multiple executors work together, they can be different.
        e.g. in the case of SPMD-style offline inference with TP=2,
        users can launch 2 engines/executors, each with only 1 worker.
        All workers have rpc_rank=0, but they have different ranks in the TP
        group.
        """
        self.rpc_rank = rpc_rank
        self.worker: Optional[WorkerBase] = None
        # do not store this `vllm_config`, `init_worker` will set the final
        # one. TODO: investigate if we can remove this field in
        # `WorkerWrapperBase`, `init_cached_hf_modules` should be
        # unnecessary now.
        if vllm_config.model_config is not None:
            # it can be None in tests
            trust_remote_code = vllm_config.model_config.trust_remote_code
            if trust_remote_code:
                # note: lazy import to avoid importing torch before initializing
                from vllm.utils import init_cached_hf_modules
                init_cached_hf_modules()

    def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
        """
        Adjust the rpc_rank based on the given mapping.
        It is only used during the initialization of the executor,
        to adjust the rpc_rank of workers after we create all workers.

        A rank that is not a key of `rank_mapping` is left unchanged.
        """
        if self.rpc_rank in rank_mapping:
            self.rpc_rank = rank_mapping[self.rpc_rank]

    def update_environment_variables(self, envs_list: List[Dict[str,
                                                                str]]) -> None:
        """Apply this worker's entry of `envs_list` (indexed by `rpc_rank`)
        to the process environment.
        """
        envs = envs_list[self.rpc_rank]
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # overwriting CUDA_VISIBLE_DEVICES is desired behavior
            # suppress the warning in `update_environment_variables`
            del os.environ[key]
        update_environment_variables(envs)

    def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.

        `all_kwargs` is indexed by `rpc_rank`; the selected dict must contain
        a non-None `vllm_config`.
        """
        kwargs = all_kwargs[self.rpc_rank]
        self.vllm_config = kwargs.get("vllm_config", None)
        assert self.vllm_config is not None, (
            "vllm_config is required to initialize the worker")
        enable_trace_function_call_for_thread(self.vllm_config)

        from vllm.plugins import load_general_plugins
        load_general_plugins()

        # `worker_cls` is either a qualified name (str) or a cloudpickled
        # class (bytes).
        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
            worker_class = resolve_obj_by_qualname(
                self.vllm_config.parallel_config.worker_cls)
        else:
            assert isinstance(self.vllm_config.parallel_config.worker_cls,
                              bytes)
            # NOTE(review): unpickling executes arbitrary code; this assumes
            # `worker_cls` bytes only ever come from trusted configuration.
            worker_class = cloudpickle.loads(
                self.vllm_config.parallel_config.worker_cls)
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)
            assert self.worker is not None

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        """Run `method` on the worker, or on this wrapper while no worker has
        been created yet. Any exception is logged and then re-raised.
        """
        try:
            target = self if self.worker is None else self.worker
            return run_method(target, method, args, kwargs)
        except Exception:
            # if the driver worker also execute methods,
            # exceptions in the rest worker may cause deadlock in rpc like ray
            # see https://github.com/vllm-project/vllm/issues/3455
            # print the error and inform the user to solve the error
            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise

    def __getattr__(self, attr):
        """Forward any attribute not found on the wrapper to the worker."""
        # `__getattr__` only runs after normal lookup failed. If the missing
        # name is `worker` itself (instance created without `__init__`, e.g.
        # while being copied or unpickled), forwarding would recurse forever.
        if attr == "worker":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'worker'")
        return getattr(self.worker, attr)

613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631

def extract_previous_hidden_states(
        data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
            Dict[str, torch.Tensor]:
    """Pull `previous_hidden_states` out of `data`, if it is there.

    The result is either an empty dict or a dict with the single key
    `previous_hidden_states`, ready to be splatted as extra keyword arguments
    into a later execute_model call. This is used in draft models like EAGLE.
    """
    key = "previous_hidden_states"

    # Non-driver workers hand in the broadcast dict: copy the entry through
    # unchanged when it is present.
    if isinstance(data, dict):
        return {key: data[key]} if key in data else {}

    # The driver worker hands in an ExecuteModelRequest: the tensor lives on
    # the `hidden_states` attribute of its `previous_hidden_states` field.
    previous = data.previous_hidden_states
    if previous is None:
        return {}
    return {key: previous.hidden_states}