# attention.py (7.29 KB)
# File-viewer navigation: Newer / Older
1
import random
2
from typing import List, Optional
3

4
from flash_attn.flash_attention import FlashAttention
5
6
7
8
import torch

from cacheflow import attention_ops

9
10
# Upper bound on the random context lengths and sequence lengths drawn by the
# tests in this file.
MAX_SEQ_LEN = 4096

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Naive scaled dot-product attention used as a numerical reference.

    Args:
        query: tensor indexed ``[q, h, d]`` (query position, head, head dim).
        key: tensor indexed ``[k, h, d]``.
        value: tensor indexed ``[k, h, d]``.
        scale: factor the query is multiplied by before the dot product.
        attn_mask: optional additive bias, broadcast against the ``[h, q, k]``
            score tensor before the softmax.

    Returns:
        Attention output indexed ``[q, h, d]``.
    """
    scores = torch.einsum('qhd,khd->hqk', query * scale, key)
    if attn_mask is not None:
        scores = scores + attn_mask
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum('hqk,khd->qhd', probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
) -> None:
    """Reference (pure PyTorch) single-query attention over a paged KV cache.

    For each query token ``i``, gathers the first ``context_lens[i]`` cached
    keys/values through ``block_tables[i]`` and attends to them with
    ``ref_masked_attention``. The result is written into ``output`` in place.

    Args:
        output: destination; ``output[i]`` receives a
            ``(num_heads, head_size)`` result for token ``i``.
        query: ``(num_tokens, num_heads, head_size)``.
        key_cache: ``(num_blocks, num_heads, head_size // x, block_size, x)``;
            the two split head-dim axes are merged back into ``head_size``.
        value_cache: ``(num_blocks, num_heads, head_size, block_size)``;
            ``num_heads``, ``head_size`` and ``block_size`` are read from it.
        block_tables: ``(num_tokens, max_blocks)``; entry ``[i, b]`` is the
            cache block holding logical positions
            ``b * block_size .. (b + 1) * block_size - 1`` of context ``i``.
        context_lens: ``(num_tokens,)``; number of cached positions each
            token attends to. Expected to be at least 1.
    """
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    scale = 1.0 / (head_size ** 0.5)

    # Fetch every length with a single host transfer: calling ``int()`` on a
    # device tensor element synchronizes, so doing it per position is slow.
    context_lens_host = context_lens.tolist()
    for i in range(query.shape[0]):
        context_len = int(context_lens_host[i])
        q = query[i].unsqueeze(0)

        # Map logical positions 0..context_len-1 to (block, offset) pairs.
        positions = torch.arange(context_len, device=block_tables.device)
        block_numbers = block_tables[i, positions // block_size].long()
        block_numbers = block_numbers.to(key_cache.device)
        block_offsets = (positions % block_size).to(key_cache.device)

        # Advanced indices separated by slices move the gathered axis to the
        # front: keys -> (context_len, num_heads, head_size // x, x).
        keys = key_cache[block_numbers, :, :, block_offsets, :]
        keys = keys.reshape(context_len, num_heads, head_size)
        # values -> (context_len, num_heads, head_size).
        values = value_cache[block_numbers, :, :, block_offsets]

        out = ref_masked_attention(q, keys, values, scale)
        out = out.view(num_heads, head_size)
        output[i].copy_(out, non_blocking=True)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over a packed batch of variable-length sequences.

    Args:
        cu_seq_lens: cumulative offsets of length ``num_seqs + 1``; sequence
            ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]``.
        query: packed tensor indexed ``[token, head, head_dim]``.
        key: packed tensor with the same layout as ``query``.
        value: packed tensor with the same layout as ``query``.
        dtype: dtype the additive causal mask is cast to.

    Returns:
        Per-sequence outputs concatenated along the token axis, in the same
        row order as ``query``.
    """
    head_size = query.shape[-1]
    scale = 1.0 / (head_size ** 0.5)

    ref_outputs = []
    for start_idx, end_idx in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq_len = end_idx - start_idx

        # Causal mask: a large negative bias strictly above the diagonal. In
        # fp16, -1e5 overflows to -inf. That is still safe because the diagonal
        # is never masked, so every softmax row keeps a finite score.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5
        # Follow the inputs' device instead of assuming 'cuda'.
        attn_mask = attn_mask.to(dtype=dtype, device=query.device)

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


def test_single_query_cached_kv_attention(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    """Compare the paged-KV single-query kernel against the PyTorch reference.

    Builds random CUDA query / key-cache / value-cache tensors. Each token gets
    a random context length in ``[1, MAX_SEQ_LEN]`` and a random block table.
    The test then checks that ``attention_ops.single_query_cached_kv_attention``
    matches ``ref_single_query_cached_kv_attention``.

    Args:
        num_tokens: number of query tokens, each with its own context.
        num_heads: number of attention heads.
        head_size: per-head hidden size. It is split as ``head_size // x`` in
            the key cache, so it should be a multiple of ``x``.
        block_size: number of token slots per cache block.
        num_blocks: total number of blocks in each cache.
        dtype: element dtype of the query and of both caches.
    """
    query = torch.randn(
        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
    # The key cache splits the head dimension into chunks of ``x`` elements,
    # where ``x`` elements occupy 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')

    # Every token gets enough randomly chosen (possibly shared) blocks for the
    # longest context. Shorter contexts only read the leading entries.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = [
        [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        for _ in range(num_tokens)
    ]
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty_like(query)
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
    )

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
    )
    # NOTE(woosuk): Due to the difference in the data types the two
    # implementations use for attention softmax logits and accumulation,
    # there is a small difference in the final outputs.
    # We should use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
) -> None:
    """Compare FlashAttention's causal packed-batch attention with the reference.

    Samples ``num_seqs`` distinct sequence lengths from ``[1, MAX_SEQ_LEN)``
    and packs random Q/K/V for all of them into one token axis. It then runs
    ``FlashAttention`` with ``causal=True`` and checks the result against
    ``ref_multi_query_kv_attention``.

    Args:
        num_seqs: number of sequences in the packed batch.
        num_heads: number of attention heads.
        head_size: per-head hidden size.
        dtype: dtype of Q/K/V and of the reference causal mask.
    """
    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    max_seq_len = max(seq_lens)
    num_tokens = sum(seq_lens)

    # Cumulative offsets: sequence i occupies packed tokens
    # [cu_seq_lens[i], cu_seq_lens[i + 1]).
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')

    scale = float(1.0 / (head_size ** 0.5))
    query = torch.randn(
        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
    key = torch.rand_like(query)
    value = torch.rand_like(query)

    # Pack Q, K and V along a new axis 1, which is the layout passed to
    # FlashAttention here.
    qkv = torch.stack([query, key, value], dim=1)
    flash_attn = FlashAttention(softmax_scale=scale)
    output = flash_attn(
        qkv,
        cu_seqlens=cu_seq_lens,
        max_s=max_seq_len,
        causal=True,
    )[0]

    # The reference slices with plain Python ints.
    cu_seq_lens = cu_seq_lens.cpu().tolist()
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


@torch.inference_mode()
def test_attention(seed: int) -> None:
    """Run both attention tests over a sweep of dtypes, block and head sizes.

    Args:
        seed: seed for Python's ``random`` module and for PyTorch's
            generators. ``random`` draws the context lengths, sequence lengths
            and block tables; PyTorch fills the tensors.
    """
    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
    # the test fails due to the precision issue. Re-run the test if it fails.
    # All generators below are seeded, so a re-run needs a different seed
    # to draw different inputs.
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    for dtype in [torch.half, torch.float]:
        for block_size in [8, 16]:
            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                print(f'Testing single_query_cached_kv_attention with '
                      f'dtype={dtype}, block_size={block_size}, '
                      f'head_size={head_size}')
                test_single_query_cached_kv_attention(
                    num_tokens=37,
                    num_heads=3,
                    head_size=head_size,
                    block_size=block_size,
                    num_blocks=1024,
                    dtype=dtype,
                )

    # NOTE(woosuk): FlashAttention does not support FP32.
    for dtype in [torch.half]:
        # NOTE(woosuk): FlashAttention does not support head_size > 128.
        for head_size in [64, 80, 96, 128]:
            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                  f'head_size={head_size}')
            test_multi_query_kv_attention(
                num_seqs=11,
                num_heads=3,
                head_size=head_size,
                dtype=dtype,
            )

# Script entry point: run the full sweep with a fixed seed.
if __name__ == '__main__':
    test_attention(seed=0)