# abstract.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type,
                    TypeVar)

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends.

    Every hook below is an abstract static method. Because this class
    derives from ABC, a subclass that leaves any hook unimplemented cannot
    be instantiated.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the AttentionImpl subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Build an AttentionMetadata object.

        The accepted arguments are backend-specific; they are not
        constrained by this interface.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape for the given cache geometry.

        The dimension order of the returned tuple is left to the concrete
        backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move blocks from `src_kv_cache` to `dst_kv_cache`.

        NOTE(review): `src_to_dst` presumably pairs source block indices
        with destination block indices; its exact layout is not specified
        here -- confirm against a concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each cache in `kv_caches`.

        NOTE(review): `src_to_dists` presumably pairs source block indices
        with destination block indices; its exact layout is not specified
        here -- confirm against a concrete backend.
        """
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # NOTE: this class does not derive from ABC, so the @abstractmethod
    # markers on the properties below are not enforced at instantiation
    # time. A subclass that does not override them inherits bodies that
    # return None.

    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored at offset 3 in block 2, offset 2 in
    # block 0, and offset 1 in block 1, respectively
    # (slot = block * block_size + offset).
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Args:
            skip_fields: Names of dataclass fields to leave out of the
                result. None (the default) leaves nothing out.

        Returns:
            A dict mapping each remaining field name to the very object
            stored on this instance (no copy is made), in field
            declaration order.
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


# Type variable for the AttentionMetadata subclass that a given AttentionImpl
# is parameterized over.
T = TypeVar("T", bound=AttentionMetadata)


class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention implementations.

    Generic over `T`, the AttentionMetadata subtype that `forward`
    receives as `attn_metadata`.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure the attention implementation.

        Subclasses must override this constructor; the base version always
        raises NotImplementedError, so it must not be reached via super().

        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            scale: Scaling factor; how it is applied is up to the
                concrete implementation.
            num_kv_heads: Number of key/value heads, or None.
                NOTE(review): presumably None means "same as num_heads"
                -- confirm against concrete implementations.
            alibi_slopes: Optional list of ALiBi slopes.
            sliding_window: Optional sliding-window size.
            kv_cache_dtype: KV cache dtype given as a string; defaults to
                "auto", whose meaning is decided by the implementation.
            blocksparse_params: Optional dict of block-sparse settings;
                its keys are not defined by this interface.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Run attention and return the output tensor.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: KV cache tensor.
            attn_metadata: Backend-specific metadata of type `T`.
            kv_scale: Scale factor for the KV cache; defaults to 1.0.
                NOTE(review): how it is applied is implementation-defined.

        Tensor shapes are not fixed by this interface.
        """
        raise NotImplementedError