"""Multi-head attention."""
from typing import Optional

import torch
import torch.nn as nn
from xformers import ops as xops

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata


class GPTCacheFlowAttention(nn.Module):
    """GPT-style multi-head attention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)
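        # xFormers' CUTLASS-based attention kernel (forward pass only).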
        self.attn_op = xops.fmha.cutlass.FwOp()

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
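        # attn_bias is expected to carry a (block-diagonal) causal mask so that
        # each token attends only to earlier tokens within its own prompt.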
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
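        # The custom attention kernel is compiled only for a fixed set of head
        # sizes, so reject unsupported sizes up front.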
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)
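        # Both the prompt and generation attention ops write into slices of
        # this tensor.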

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        num_valid_tokens = input_metadata.num_valid_tokens
        if num_valid_tokens > 0:
            # The stride is 3 because the key and value are sliced from qkv.
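            # slot_mapping gives the destination slot in the paged KV cache
            # for each valid token.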
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(scale)

        # Create the cos and sin cache.
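        # inv_freq[i] = 1 / base^(2i / rotary_dim) is the standard RoPE
        # frequency schedule; the outer product with the positions gives the
        # rotation angles, cached as [cos | sin] along the last dimension.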
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,            # [num_tokens]
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
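        # The custom op updates query and key in place.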
        head_size = value_cache.shape[2]
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )