abstract.py 6.66 KB
Newer Older
1
from abc import ABC, abstractmethod
2
from contextlib import contextmanager
3
from dataclasses import dataclass, fields
4
from enum import Enum, auto
5
6
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)
7
8
9

import torch

10
if TYPE_CHECKING:
11
12
13
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)
14

15

16
17
18
19
20
21
class AttentionType(Enum):
    """Kinds of attention a layer can be asked to perform."""

    # Decoder attention between previous layer Q/K/V.
    DECODER = auto()

    # Encoder attention between previous layer Q/K/V.
    ENCODER = auto()

    # Attention between decoder Q and encoder K/V.
    ENCODER_DECODER = auto()


22
23
24
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend names the concrete classes it works with (implementation,
    metadata, state and metadata builder) and declares the KV-cache helper
    operations below. Every method except ``make_metadata``,
    ``make_metadata_builder`` and ``advance_step`` is abstract.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the ``AttentionState`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate the backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by ``get_metadata_cls()``.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the ``AttentionMetadataBuilder`` subclass for this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        """Instantiate the backend's metadata builder class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by ``get_builder_cls()``.
        """
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape for the given sizes as a tuple of ints.

        The layout of the returned dimensions is decided by the concrete
        backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        NOTE(review): ``src_to_dst`` presumably maps source block indices to
        destination block indices -- confirm against a concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each of ``kv_caches``.

        NOTE(review): ``src_to_dists`` presumably holds source/destination
        block index pairs -- confirm against a concrete backend. The
        parameter name is kept as-is for keyword-argument compatibility.
        """
        raise NotImplementedError

    def advance_step(self, num_seqs: int, num_queries: int):
        """Optional hook; not abstract.

        Backends that do not override it raise ``NotImplementedError`` when
        it is called.
        """
        raise NotImplementedError

89
90

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together.

    Note: this class does not derive from ``ABC``, so the ``@abstractmethod``
    markers on the properties below are not enforced at instantiation time;
    they document what subclasses are expected to override.
    """
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Args:
            skip_fields: Names of dataclass fields to leave out of the
                result. ``None`` (the default) skips nothing.

        Returns:
            A dict mapping each remaining field name to the attribute
            object itself (the same reference, not a copy).
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


134
# Type variable for the metadata class a backend works with; bound to
# AttentionMetadata so type checkers only accept it or its subclasses.
T = TypeVar("T", bound=AttentionMetadata)
135
136


137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
class AttentionState(ABC, Generic[T]):
    """Container for backend-specific attention objects that are kept
    and reused for as long as the model runner lives."""

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for the given model runner."""

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager to wrap CUDA graph capture in."""
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Return a clone of this state to store in CUDA graph metadata."""

    @abstractmethod
    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
        """Return the attention metadata to use when capturing a CUDA
        graph for ``batch_size``."""

    @abstractmethod
    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
        """Return the attention-specific input buffers used for CUDA
        graph capture."""

    @abstractmethod
    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata: T) -> None:
        """Update the ``input_buffers`` dict in place ahead of a CUDA
        graph replay."""

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Get the state ready for a forward pass."""


178
179
180
181
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    Generic over ``T``, the metadata type that ``build`` returns.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create a builder tied to ``input_builder``.

        The base implementation raises ``NotImplementedError``; subclasses
        must provide their own constructor.
        """
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        Returns:
            A metadata object of the builder's type ``T``.
        """
        raise NotImplementedError


192
class AttentionImpl(ABC, Generic[T]):
    """Abstract interface for an attention implementation.

    Generic over ``T``, the metadata type that ``forward`` receives. Both
    methods are abstract and raise ``NotImplementedError`` here.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Configure the implementation.

        Only ``num_heads``, ``head_size`` and ``scale`` are required; every
        other argument has a default (``None``, or ``"auto"`` for
        ``kv_cache_dtype``). How each option is interpreted is up to the
        concrete implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention and return the result as a tensor.

        ``attn_type`` defaults to ``AttentionType.DECODER``.

        NOTE(review): ``k_scale`` / ``v_scale`` are presumably scaling
        factors for the cached keys and values -- confirm against a
        concrete implementation.
        """
        raise NotImplementedError