"tests/tool_parsers/test_jamba_tool_parser.py" did not exist on "bc546f76a145087ceae59e842443193aaf8a91a0"
# Source-viewer residue (not part of the module), kept as comments:
#   abstract.py 8.45 KB
#   Newer Older
1
from abc import ABC, abstractmethod
2
from contextlib import contextmanager
3
from dataclasses import dataclass, fields
4
5
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
                    Protocol, Set, Tuple, Type, TypeVar)
6
7
8

import torch

9
10
from vllm.multimodal import MultiModalPlaceholderMap

11
if TYPE_CHECKING:
12
13
14
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)
15

16

17
18
19
20
21
22
23
24
25
26
27
28
29
class AttentionType:
    """
    Names of the supported attention types.

    The values are plain strings so that they stay compatible with
    `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V.
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V, used in
    # encoder-decoder models.
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V (encoder-only case).
    ENCODER_ONLY = "encoder_only"
    # Attention between decoder Q and encoder K/V, used in
    # encoder-decoder models.
    ENCODER_DECODER = "encoder_decoder"
30
31


32
33
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend is described through static methods that return its name and
    the classes it works with (implementation, metadata, state and metadata
    builder), plus helpers that operate on KV-cache tensors. Concrete
    backends must override every method marked ``@abstractmethod``.
    """
    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the `AttentionImpl` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the `AttentionMetadata` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the `AttentionState` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate the backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by `get_metadata_cls()`.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the `AttentionMetadataBuilder` subclass of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return a shape tuple derived from the given cache dimensions.

        NOTE(review): presumably the shape of the backend's KV-cache
        tensor -- confirm against concrete backends.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Operate on a source and a destination KV cache; returns nothing.

        NOTE(review): presumably `src_to_dst` maps source block indices to
        destination block indices -- confirm against concrete backends.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Operate on a list of KV caches; returns nothing.

        NOTE(review): presumably `src_to_dists` describes which blocks are
        copied where inside each cache -- confirm against concrete backends.
        """
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        """Optional hook; the base implementation always raises.

        Unlike the methods above this one is not abstract, so subclasses
        are not forced to override it; calling it on a backend that does
        not override it raises `NotImplementedError`.
        """
        raise NotImplementedError

100
101

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together.

    NOTE(review): this class does not derive from `ABC`, so the
    `@abstractmethod` markers on `prefill_metadata` / `decode_metadata`
    are not enforced at instantiation time; the base properties simply
    return `None` unless a subclass overrides them.
    """
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Args:
            skip_fields: Names of dataclass fields to leave out of the
                result. `None` (the default) means no field is skipped.

        Returns:
            A dict mapping each remaining field name to the attribute
            object itself (no copy is made, no recursion into values).
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


158
# Type variable for attention-metadata classes; bound to AttentionMetadata,
# so it can only be substituted by AttentionMetadata or a subclass of it.
T = TypeVar("T", bound=AttentionMetadata)
159
160


161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner.

    The class is generic over `T`, the attention-metadata type handled by
    the backend. Every method here is abstract and must be overridden.
    """

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for the given model runner."""
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs."""
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size."""
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture."""
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay."""
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...


211
212
213
214
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    Generic over `T`, the attention-metadata type that `build` returns.
    All three methods are abstract and raise `NotImplementedError` in
    this base class.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError


231
232
class AttentionLayer(Protocol):
    """Structural (duck-typed) interface of an attention layer.

    Any object that provides the four scale attributes below and a
    compatible `forward` method satisfies this protocol; no inheritance
    is required.
    """

    # Key / value scales, available both as tensors and as plain floats.
    # NOTE(review): presumably KV-cache quantization scales -- confirm.
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _k_scale_float: float
    _v_scale_float: float

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on query/key/value with the given KV cache and
        attention metadata, returning a tensor."""
        ...


249
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention implementations.

    Generic over `T`, the attention-metadata type accepted by `forward`.
    Both `__init__` and `forward` are abstract and raise
    `NotImplementedError` in this base class.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Configure the implementation.

        Only `num_heads`, `head_size` and `scale` are required; every
        other parameter has a default. `attn_type` defaults to
        `AttentionType.DECODER` and `kv_cache_dtype` to `"auto"`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention for `layer` and return the result tensor.

        `output` is optional and defaults to `None`.
        NOTE(review): presumably a pre-allocated buffer to write the result
        into for backends that accept one -- confirm with implementations.
        """
        raise NotImplementedError