"deploy/operator/internal/consts/consts.go" did not exist on "93acc631da3ecf057ae25392e0612368b2ab1243"
abstract.py 3.49 KB
Newer Older
1
2
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
3
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

import torch


class AttentionBackend(ABC):
    """Interface that every attention backend has to implement.

    Every hook is an abstract static method, so a concrete backend overrides
    all of them and none of them relies on instance state.
    """

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Build the ``AttentionMetadata`` object for this backend.

        The accepted arguments are backend specific.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape as a tuple of ints.

        The shape is derived from the number of blocks, the block size, the
        number of KV heads and the head size; the layout itself is up to the
        backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Move blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        NOTE(review): ``src_to_dst`` presumably maps a source block index to
        its destination block index -- confirm against the implementations.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Copy blocks inside the given KV caches.

        NOTE(review): ``src_to_dists`` presumably maps one source block index
        to the list of destination block indices -- confirm against the
        implementations.
        """
        raise NotImplementedError


@dataclass
class AttentionMetadataPerStage:
    """Attention metadata for a specific stage, i.e., prefill or decode."""

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Return the dataclass fields as a dict without copying the values.

        Similar to ``dataclasses.asdict``, but avoids deepcopying: every
        value in the result is the very object stored on the instance.

        Returns:
            A dict mapping each field name to the current attribute value,
            in field declaration order.
        """
        # Note that if we add dataclasses as fields, they will need
        # similar handling: a nested dataclass is returned as-is here
        # instead of being converted to a dict.
        return {f.name: getattr(self, f.name) for f in fields(self)}


63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Type variable for the per-stage metadata type that parameterizes the
# generic AttentionMetadata; bound to AttentionMetadataPerStage, so only that
# class or its subclasses are accepted.
T = TypeVar("T", bound=AttentionMetadataPerStage)


@dataclass
class AttentionMetadata(Generic[T]):
    """Attention metadata covering the prefill and decode parts of a batch.

    ``__post_init__`` checks that the counters and the per-stage metadata
    agree: a positive token count requires the matching metadata object.
    """
    # How many prefill requests are in the batch.
    num_prefills: int
    # How many prefill tokens are in the batch.
    num_prefill_tokens: int
    # How many decode tokens are in the batch. This equals the number of
    # decode requests.
    num_decode_tokens: int
    # Per-stage metadata for the prefill requests of the batch.
    # None when the batch has no prefill requests.
    prefill_metadata: Optional[T]
    # Per-stage metadata for the decode requests of the batch.
    # None when the batch has no decode requests.
    decode_metadata: Optional[T]
    # (num_tokens,). Index of the token slot each input token is stored
    # into. E.g., with `slot_mapping` [35, 2, 17] and a block size of 16,
    # the three tokens go to offset 3 of block 2 (35 = 2 * 16 + 3),
    # offset 2 of block 0, and offset 1 of block 1, respectively.
    slot_mapping: torch.Tensor
    # Data type of the kv cache.
    kv_cache_dtype: str

    def __post_init__(self):
        """Assert that token counts and per-stage metadata are consistent."""
        # Prefill tokens imply at least one prefill request and its metadata.
        batch_has_prefill = self.num_prefill_tokens > 0
        if batch_has_prefill:
            assert self.num_prefills > 0
            assert self.prefill_metadata is not None

        # Decode tokens imply decode metadata.
        batch_has_decode = self.num_decode_tokens > 0
        if batch_has_decode:
            assert self.decode_metadata is not None


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class AttentionImpl(ABC):
    """Abstract base class for an attention implementation.

    Subclasses must override both ``__init__`` and ``forward``; the base
    versions only raise ``NotImplementedError``.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Configure the attention implementation.

        Args:
            num_heads: Number of attention heads.
            head_size: Size of each attention head.
            scale: Scaling factor (presumably applied to the attention
                scores -- confirm against the implementations).
            num_kv_heads: Number of key/value heads, or None.
                NOTE(review): the meaning of None is implementation
                defined here -- confirm.
            alibi_slopes: Optional ALiBi slopes, or None.
            sliding_window: Optional sliding window size, or None.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_scale: float,
    ) -> torch.Tensor:
        """Run attention for the given inputs and return the result tensor.

        Args:
            query: Query tensor.
            key: Key tensor.
            value: Value tensor.
            kv_cache: KV cache tensor.
            attn_metadata: Metadata describing the batch.
            kv_scale: Scale factor associated with the KV cache
                (exact use is implementation defined -- confirm).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError