"""Multi-head attention."""
from typing import Optional

import torch
import torch.nn as nn
from xformers import ops as xops

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


class GPTCacheFlowAttention(nn.Module):
    """GPT-style multi-head attention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
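
    A rough usage sketch (the values are illustrative; InputMetadata is built
    by the input preparation code outside this module):

        attn = GPTCacheFlowAttention(num_heads=32, head_size=128,
                                     scale=128 ** -0.5)
        out = attn(
            query, key, value,       # each [num_tokens, num_heads * head_size]
            key_cache, value_cache,  # paged KV cache tensors (may be None)
            input_metadata,          # prompt/generation split, block tables, etc.
            cache_event,             # optional CUDA event to wait on
        )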
    """

    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.attn_op = xops.fmha.cutlass.FwOp()

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
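        # NOTE: attn_bias comes from InputMetadata and is expected to encode
        # the per-prompt causal masking, so that prompts of different lengths
        # can be batched into a single xformers call.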
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
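        # The KV cache is paged: keys and values live in fixed-size blocks that
        # are addressed through input_metadata.block_tables. The trailing `x`
        # dimension of key_cache packs consecutive elements of the head
        # dimension for vectorized loads; its value is determined by the cache
        # engine, not by this module.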
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: Optional[torch.Tensor],      # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: Optional[torch.Tensor],    # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
            and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens."
            )
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
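        # cache[:, :rotary_dim // 2] holds the cos values and
        # cache[:, rotary_dim // 2:] holds the sin values for each position.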

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,                # [num_tokens]
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
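        # (A rough pure-PyTorch sketch of this transform is given in
        # _rotary_embedding_neox_ref at the bottom of this file.)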
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
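

# NOTE: The following is a rough pure-PyTorch sketch of the GPT-NeoX-style
# rotary embedding that pos_encoding_ops.rotary_embedding_neox applies in
# place on the GPU. It is for reference only and is not used by the classes
# above; the exact behavior is defined by the CUDA kernel.
def _rotary_embedding_neox_ref(
    positions: torch.Tensor,      # [num_tokens]
    x: torch.Tensor,              # [num_tokens, num_heads, head_size]
    cos_sin_cache: torch.Tensor,  # [max_position, rotary_dim]
) -> torch.Tensor:
    rotary_dim = cos_sin_cache.shape[-1]
    # Look up the cos/sin values for each position and broadcast over heads.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    # Rotate the first rotary_dim elements of each head; pass the rest through.
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    x_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((x_rot, x_pass), dim=-1)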