# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sequence and its related classes."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch

if TYPE_CHECKING:
    # Imported for static type checkers only; this branch never runs at
    # runtime, so the worker module is not imported when this file loads.
    from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
else:
    # Runtime placeholder so that annotations mentioning KVConnectorOutput
    # can still be evaluated without the real class being imported.
    KVConnectorOutput = Any


# Cannot use msgspec.Struct here because Dynamo does not support it.
@dataclass
class IntermediateTensors:
19
20
21
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
22

23
    Each stage also needs to handle its own kv_connector_output.
24
25
    """

26
    tensors: dict[str, torch.Tensor]
27
    kv_connector_output: KVConnectorOutput | None
28

29
30
31
32
33
    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        kv_connector_output: KVConnectorOutput | None = None,
    ) -> None:
34
35
36
37
38
        # manually define this function, so that
        # Dynamo knows `IntermediateTensors()` comes from this file.
        # Otherwise, dataclass will generate this function by evaluating
        # a string, and we will lose the information about the source file.
        self.tensors = tensors
39
        self.kv_connector_output = kv_connector_output
40

41
    def __getitem__(self, key: str | slice):
42
43
44
45
46
        if isinstance(key, str):
            return self.tensors[key]
        elif isinstance(key, slice):
            return self.__class__({k: v[key] for k, v in self.tensors.items()})

47
    def __setitem__(self, key: str, value: torch.Tensor):
48
49
        self.tensors[key] = value

50
51
52
    def items(self):
        return self.tensors.items()

53
54
55
56
    def __len__(self):
        return len(self.tensors)

    def __eq__(self, other: object):
57
58
59
60
        if not isinstance(other, self.__class__):
            return False
        if self.tensors.keys() != other.tensors.keys():
            return False
61
        return all(torch.equal(self.tensors[k], other.tensors[k]) for k in self.tensors)
62
63
64

    def __repr__(self) -> str:
        return f"IntermediateTensors(tensors={self.tensors})"
65
66
67
68
69
70
71
72
73

    @staticmethod
    def empty_like(
        intermediate_tensors: "IntermediateTensors",
    ) -> "IntermediateTensors":
        tensors = {
            k: torch.empty_like(v) for k, v in intermediate_tensors.tensors.items()
        }
        return IntermediateTensors(tensors)