from typing import List, Optional

from flash_attn.flash_attention import FlashAttention
import torch
import torch.nn as nn

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.models import InputMetadata


class GPTCacheFlowAttention(nn.Module):
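    """Multi-head attention over a paged (block-based) KV cache.

    Prompt tokens are processed together with FlashAttention
    (multi_query_kv_attention), while generation tokens attend over the
    cached keys/values through the custom paged-attention kernel
    (single_query_cached_kv_attention).
    """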

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = float(scale)

        self.flash_attn = FlashAttention(softmax_scale=self.scale)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,       # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,          # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,        # [num_prompt_tokens, num_heads, head_size]
        prompt_lens: List[int],
    ) -> None:
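        """Causal self-attention over the packed prompt tokens.

        All prompts are concatenated along the token dimension; FlashAttention's
        variable-length interface (cu_seqlens / max_s) keeps attention confined
        to each prompt. The result is written into `output` in place.
        """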
        if query.dtype == torch.float:
            raise ValueError('The float data type is not supported by '
                             'FlashAttention. Use the half data type instead.')
        head_size = query.shape[2]
        if head_size > 128:
            raise ValueError('FlashAttention does not support head_size > 128.')

        device = query.device
        prefix_sum = [0]
        for prompt_len in prompt_lens:
            prefix_sum.append(prefix_sum[-1] + prompt_len)
        prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
        max_prompt_len = max(prompt_lens)

        # FIXME(woosuk): Unnecessary copy. Optimize this.
        qkv = torch.stack([query, key, value], dim=1)
        out = self.flash_attn(
            qkv,
            cu_seqlens=prefix_sum,
            max_s=max_prompt_len,
            causal=True,
        )[0]
        # FIXME(woosuk): Unnecessary copy. Optimize this.
        output.copy_(out, non_blocking=True)

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
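        """Attention for generation tokens against the paged KV cache.

        Each query token attends only to the cached keys/values of its own
        sequence; the kernel gathers the relevant cache blocks via
        input_metadata.block_tables and bounds the attention span with
        input_metadata.context_lens. The result is written into `output`
        in place.
        """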
        head_size = value_cache.shape[2]
        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
        if head_size not in supported_head_sizes:
            raise ValueError(f'head_size ({head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{supported_head_sizes}.')

        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
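        """Run attention for a mixed batch of prompt and generation tokens.

        The first num_prompt_tokens rows of query/key/value belong to prompts
        and go through FlashAttention; the remaining valid rows are generation
        tokens and read from the paged KV cache. New keys/values are written
        into the cache in between, after any in-flight cache op completes.
        """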
        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Prune out paddings if any.
        query = query[:input_metadata.num_valid_tokens]
        key = key[:input_metadata.num_valid_tokens]
        value = value[:input_metadata.num_valid_tokens]

        # Reshape the input tensors.
        num_heads = value_cache.shape[1]
        head_size = value_cache.shape[2]
        query = query.view(-1, num_heads, head_size)
        key = key.view(-1, num_heads, head_size)
        value = value.view(-1, num_heads, head_size)
        output = output.view(-1, num_heads, head_size)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.prompt_lens,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        cache_ops.reshape_and_cache(
            key, value, key_cache, value_cache, input_metadata.slot_mapping)

        if input_metadata.num_generation_tokens > 0:
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:],
                query[num_prompt_tokens:],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, num_heads * head_size)


class OPTCacheFlowAttention(GPTCacheFlowAttention):
    """OPT uses the same attention mechanism as GPT."""

    def __init__(self, scale: float) -> None:
        super().__init__(scale)


class LlamaCacheFlowAttention(GPTCacheFlowAttention):
    """Llama uses GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        scale: float,
        head_size: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
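        # Cache layout: [max_position, head_size], with the cosines for the
        # head_size // 2 rotary frequencies in the first half and the
        # matching sines in the second half.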

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, head_size]
        self.register_buffer('cos_sin_cache', cache, persistent=False)

    def forward(
        self,
        positions: torch.LongTensor,            # [num_tokens]
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        out_query = torch.empty_like(query)
        out_key = torch.empty_like(key)
        pos_encoding_ops.rotary_embedding_neox(
            out_query,
            out_key,
            positions,
            query,
            key,
            self.cos_sin_cache,
        )
        return super().forward(
            out_query,
            out_key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
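

# Illustrative usage sketch (not part of the module): how a model layer is
# expected to call the attention above. The construction of the KV cache
# tensors, InputMetadata, and the optional cache_event is assumed to be done
# by the cacheflow worker/cache engine; the names below are placeholders.
#
#     attn = LlamaCacheFlowAttention(scale=head_size ** -0.5, head_size=head_size)
#     output = attn(
#         positions,               # [num_tokens]
#         query, key, value,       # [num_tokens, num_heads * head_size]
#         key_cache, value_cache,  # paged KV cache blocks
#         input_metadata,
#         cache_event=None,
#     )                            # -> [num_tokens, num_heads * head_size]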