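"""Attention layers for CacheFlow models.

Prompt tokens are processed with FlashAttention, while generation tokens
attend over a block-based KV cache through CacheFlow's custom CUDA kernels.
"""
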
from typing import Optional

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.models import InputMetadata


class GPTCacheFlowAttention(nn.Module):
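    """Multi-head causal attention with a block-based KV cache.

    Prompt tokens are handled by FlashAttention (multi_query_kv_attention);
    generation tokens are handled by the single_query_cached_kv_attention
    kernel, which reads keys/values from the cache. The key cache splits the
    head dimension into chunks of x elements ([num_blocks, num_heads,
    head_size/x, block_size, x]), presumably so the kernel can load them with
    vectorized accesses.
    """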

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        cumulative_prompt_lens: torch.Tensor,   # [num_prompts + 1]
        max_prompt_len: int,
    ) -> None:
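        """Computes attention over the prompt tokens with FlashAttention.

        All prompts are packed into single [num_prompt_tokens, ...] tensors;
        cumulative_prompt_lens gives the start offset of each prompt, as in
        FlashAttention's variable-length interface. The result is written
        into `output` in place; nothing is returned.
        """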
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[-1]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

        # Directly call FlashAttention's internal function to avoid allocating
        # a new tensor for the output.
        _flash_attn_forward(
            query,
            key,
            value,
            output,
            cumulative_prompt_lens,
            cumulative_prompt_lens,
            max_prompt_len,
            max_prompt_len,
            dropout_p=0.0,
            softmax_scale=self.scale,
            causal=True,
            return_softmax=False,
        )

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
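        """Computes attention for generation tokens (one query per sequence).

        Each query attends to the keys/values stored in the block-based cache:
        input_metadata.block_tables maps each sequence to its cache blocks and
        input_metadata.context_lens holds the context length of each sequence.
        The result is written into `output` in place.
        """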
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
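        """Runs attention for a batch mixing prompt and generation tokens.

        The first num_prompt_tokens rows of query/key/value belong to the
        prompts and the following rows (up to num_valid_tokens) to the
        generation tokens; any remaining rows are padding. Prompt attention
        uses the freshly computed keys/values directly, while the keys/values
        of all valid tokens are written into the cache before the generation
        tokens read from it.
        """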
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.cumulative_prompt_lens,
                input_metadata.max_prompt_len,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        num_valid_tokens = input_metadata.num_valid_tokens
        if num_valid_tokens > 0:
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata,
            )

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)


class OPTCacheFlowAttention(GPTCacheFlowAttention):
    """OPT uses the same attention mechanism as GPT."""

    def __init__(self, scale: float) -> None:
        super().__init__(scale)


class LlamaCacheFlowAttention(GPTCacheFlowAttention):
    """Llama uses GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        scale: float,
        head_size: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
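        """Precomputes the rotary embedding cos/sin cache.

        The cache has shape [max_position, head_size]: for each position, the
        first half of the row holds cos(pos * inv_freq) and the second half
        holds sin(pos * inv_freq), the layout consumed by
        pos_encoding_ops.rotary_embedding_neox in forward().
        """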
        super().__init__(scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, head_size]
        self.register_buffer('cos_sin_cache', cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,            # [num_tokens]
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
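        """Same as the parent forward, but first applies GPT-NeoX style rotary
        embeddings to the query and key (the custom op updates them in place).
        """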
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )