# abstract.py
1
2
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
3
from enum import Enum, auto
4
5
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)
6
7
8

import torch

9
10
11
if TYPE_CHECKING:
    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase

12

13
14
15
16
17
18
class AttentionType(Enum):
    """Where an attention layer's queries and keys/values come from."""

    # Decoder attention between previous-layer Q/K/V.
    DECODER = auto()
    # Encoder attention between previous-layer Q/K/V.
    ENCODER = auto()
    # Attention between decoder Q and encoder K/V.
    ENCODER_DECODER = auto()


19
20
21
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend ties together the classes that make up one attention
    implementation (its ``AttentionImpl``, ``AttentionMetadata`` and
    ``AttentionMetadataBuilder`` subclasses) plus static helpers that act on
    its KV cache. Every method except ``make_metadata``,
    ``make_metadata_builder`` and ``advance_step`` is abstract.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return this backend's ``AttentionImpl`` subclass."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return this backend's ``AttentionMetadata`` subclass."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate ``get_metadata_cls()`` with the given arguments."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return this backend's ``AttentionMetadataBuilder`` subclass."""
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        """Instantiate ``get_builder_cls()`` with the given arguments."""
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape for the given sizes.

        The order of dimensions is chosen by each implementation.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap KV-cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        NOTE(review): ``src_to_dst`` presumably holds (source, destination)
        block-index pairs; its exact layout is defined by implementations.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks in ``kv_caches`` as directed by the mapping.

        The parameter keeps its historical name ``src_to_dists`` so keyword
        callers keep working. NOTE(review): it presumably holds (source,
        destination) block-index pairs; the exact layout is defined by
        implementations.
        """
        raise NotImplementedError

    def advance_step(self, num_seqs: int, num_queries: int) -> None:
        """Optional per-step hook; unsupported by default.

        Unlike the methods above this is not abstract: backends that do not
        override it inherit this ``NotImplementedError``. The meaning of
        ``num_seqs`` / ``num_queries`` is defined by the overriding backend.
        """
        raise NotImplementedError

81
82

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together.

    NOTE: this class does not derive from ``ABC``, so the ``@abstractmethod``
    markers on ``prefill_metadata`` / ``decode_metadata`` are not enforced.
    The class and any subclass that does not override them can still be
    instantiated; the base properties then simply return ``None``.
    """
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored at (0-indexed) slot 3 of block 2,
    # slot 2 of block 0, and slot 1 of block 1, respectively
    # (i.e. block = index // block_size, slot = index % block_size).
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Args:
            skip_fields: Names of fields to leave out of the result.
                ``None`` (the default) keeps every field.

        Returns:
            A new dict mapping each remaining field name, in field
            definition order, to its current value. The values are the very
            objects held by ``self`` (tensors are not copied), so mutating
            them in place also mutates ``self``.
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


126
# Type variable standing for a concrete AttentionMetadata (sub)class; it
# parameterizes the generic AttentionMetadataBuilder and AttentionImpl
# classes in this module.
T = TypeVar("T", bound=AttentionMetadata)
127
128


129
130
131
132
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    Parameterized by ``T``, the ``AttentionMetadata`` type that ``build``
    returns.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Initialize the builder from ``input_builder``.

        The base implementation raises ``NotImplementedError``: subclasses
        must override ``__init__``, and calling ``super().__init__()`` would
        raise.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: Sequence lengths. NOTE(review): presumably one entry
                per sequence in the batch -- confirm against callers.
            query_lens: Query lengths, presumably aligned with ``seq_lens``.
            cuda_graph_pad_size: Padding size used for CUDA graphs; its
                exact interpretation is defined by implementations.
            batch_size: Batch size.

        Returns:
            The built metadata, an instance of ``T``.
        """
        raise NotImplementedError


143
class AttentionImpl(ABC, Generic[T]):
    """Abstract interface for the computation performed by attention.

    Parameterized by ``T``, the ``AttentionMetadata`` type that ``forward``
    accepts.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Configure the attention implementation.

        The base implementation raises ``NotImplementedError``: subclasses
        must override ``__init__``, and calling ``super().__init__()`` would
        raise.

        NOTE(review): how each argument is interpreted (e.g. what
        ``num_kv_heads=None`` or ``kv_cache_dtype="auto"`` resolve to) is
        defined by the concrete implementation, not here.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention and return the output tensor.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: KV-cache tensor. NOTE(review): presumably laid out as
                the owning backend's ``get_kv_cache_shape`` describes.
            attn_metadata: Batch metadata of type ``T``.
            k_scale: Scale factor for keys (default ``1.0``).
            v_scale: Scale factor for values (default ``1.0``).
            attn_type: Kind of attention to run; defaults to
                ``AttentionType.DECODER``.

        Tensor shapes and layouts are defined by each implementation.
        """
        raise NotImplementedError