import random
from typing import List, Optional

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import attention_ops

MAX_SEQ_LEN = 4096
TEST_SEED = 0


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference scaled dot-product attention with an optional additive mask.

    The einsum subscripts fix the layouts: ``query`` is indexed as
    (q, h, d), ``key`` and ``value`` as (k, h, d), and the result as
    (q, h, d).

    Args:
        query: Query tensor; multiplied by ``scale`` before the dot product.
        key: Key tensor.
        value: Value tensor.
        scale: Factor applied to the query.
        attn_mask: Optional tensor added to the (h, q, k) logits before the
            softmax; it must broadcast against that shape.

    Returns:
        The attention output, laid out like ``query``.
    """
    # Per-head logits between every query row and every key row.
    logits = torch.einsum('qhd,khd->hqk', query * scale, key)
    if attn_mask is not None:
        logits = logits + attn_mask
    # Normalise over the key axis, then mix the value rows accordingly.
    weights = logits.softmax(dim=-1)
    return torch.einsum('hqk,khd->qhd', weights, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
) -> None:
    """Reference single-token attention over a block-paged KV cache.

    For each row of ``query``, the keys and values of its context are
    gathered position by position through that row's block table, attended
    over with ``ref_masked_attention`` (no mask), and the result is copied
    into the matching row of ``output``.

    Args:
        output: Destination tensor, written in place one row per token.
        query: One query row per input token.
        key_cache: Key cache indexed as
            ``[block, head, :, offset_in_block, :]``; the two remaining axes
            are flattened into ``head_size``.
        value_cache: Value cache indexed as
            ``[block, head, head_dim, offset_in_block]``.
        block_tables: Per-token table of cache block numbers.
        context_lens: Per-token number of cached positions to attend over.
    """
    num_heads = value_cache.size(1)
    head_size = value_cache.size(2)
    block_size = value_cache.size(3)

    for token_idx in range(query.size(0)):
        single_query = query[token_idx].unsqueeze(0)
        table = block_tables[token_idx]
        num_positions = int(context_lens[token_idx])

        key_rows = []
        value_rows = []
        for position in range(num_positions):
            # Split the logical position into (table slot, offset in block).
            table_slot, offset = divmod(position, block_size)
            cache_block = int(table[table_slot])

            # The key cache splits head_size over two axes; flatten it back.
            key_rows.append(
                key_cache[cache_block, :, :, offset, :].reshape(
                    num_heads, head_size))
            value_rows.append(value_cache[cache_block, :, :, offset])

        attended = ref_masked_attention(
            single_query,
            torch.stack(key_rows, dim=0),
            torch.stack(value_rows, dim=0),
            1.0 / (head_size ** 0.5),
        )
        output[token_idx].copy_(
            attended.view(num_heads, head_size), non_blocking=True)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal self-attention over several packed sequences.

    ``query``, ``key`` and ``value`` hold the tokens of all sequences
    concatenated along dim 0. Each sequence is sliced out, attended over
    with a causal mask by ``ref_masked_attention``, and the per-sequence
    outputs are concatenated back in the same order.

    Args:
        cu_seq_lens: Cumulative sequence lengths; sequence ``i`` occupies
            rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]``.
        query: Packed query rows.
        key: Packed key rows.
        value: Packed value rows.
        dtype: dtype used for the additive attention mask.

    Returns:
        The packed attention output, one row per input token.
    """
    head_size = query.shape[-1]
    scale = 1.0 / (head_size ** 0.5)
    # Most negative finite value of ``dtype``: masked logits vanish in the
    # softmax without introducing infinities into the mask itself.
    mask_fill = torch.finfo(dtype).min

    ref_outputs = []
    for start_idx, end_idx in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq_len = end_idx - start_idx

        # Causal mask: everything strictly above the diagonal is masked.
        # It is created on the inputs' device rather than a hard-coded one.
        attn_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=dtype, device=query.device),
            diagonal=1) * mask_fill

        ref_outputs.append(
            ref_masked_attention(
                query[start_idx:end_idx],
                key[start_idx:end_idx],
                value[start_idx:end_idx],
                scale,
                attn_mask=attn_mask,
            ))
    return torch.cat(ref_outputs, dim=0)


def ref_multi_query_cached_kv_attention(
    cu_query_lens: List[int],
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference multi-token attention over a block-paged KV cache.

    For each packed query segment, the keys and values of its context are
    gathered position by position through the segment's block table and
    attended over with ``ref_masked_attention`` under a causal mask.

    Args:
        cu_query_lens: Cumulative query lengths; segment ``i`` occupies rows
            ``cu_query_lens[i]:cu_query_lens[i + 1]`` of ``query``.
        query: Packed query rows.
        key_cache: Key cache indexed as
            ``[block, head, :, offset_in_block, :]``; the two remaining axes
            are flattened into ``head_size``.
        value_cache: Value cache indexed as
            ``[block, head, head_dim, offset_in_block]``.
        block_tables: Per-segment table of cache block numbers.
        context_lens: Per-segment number of cached positions.
        dtype: dtype used for the additive attention mask.

    Returns:
        The packed attention output, one row per query row.
    """
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    scale = 1.0 / (head_size ** 0.5)
    # Use the most negative finite value of ``dtype`` as the mask fill, as
    # ref_multi_query_kv_attention does. A literal -1e5 is outside the fp16
    # range and turns into -inf when cast.
    mask_fill = torch.finfo(dtype).min

    num_queries = len(cu_query_lens) - 1
    ref_outputs = []
    for i in range(num_queries):
        start_idx = cu_query_lens[i]
        end_idx = cu_query_lens[i + 1]
        query_len = end_idx - start_idx
        context_len = int(context_lens[i])
        block_table = block_tables[i]

        # Causal mask with the query rows aligned to the end of the context:
        # row r may attend to positions <= r + (context_len - query_len).
        attn_mask = torch.triu(
            torch.ones(query_len, context_len, device=query.device),
            diagonal=context_len - query_len + 1) * mask_fill
        attn_mask = attn_mask.to(dtype=dtype)

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            # The key cache splits head_size over two axes; flatten it back.
            k = key_cache[block_number, :, :, block_offset, :]
            keys.append(k.reshape(num_heads, head_size))
            values.append(value_cache[block_number, :, :, block_offset])
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)

        ref_outputs.append(
            ref_masked_attention(
                query[start_idx:end_idx],
                keys,
                values,
                scale,
                attn_mask=attn_mask,
            ))
    return torch.cat(ref_outputs, dim=0)


@torch.inference_mode()
def run_single_query_cached_kv_attention(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    """Check ``attention_ops.single_query_cached_kv_attention`` once.

    Builds random queries, a random paged KV cache and random block tables
    on the GPU, runs the kernel and the pure-PyTorch reference on the same
    inputs, and asserts that the two outputs agree within tolerance.

    Args:
        num_tokens: Number of query tokens (one context each).
        num_heads: Number of attention heads.
        head_size: Size of each head.
        block_size: Number of positions stored per cache block.
        num_blocks: Number of blocks in the key and value caches.
        dtype: dtype of the queries and caches.

    Raises:
        AssertionError: If the kernel output differs from the reference.
    """
    qkv = torch.empty(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    qkv.uniform_(-1e-3, 1e-3)
    query, _, _ = qkv.unbind(dim=1)

    # x is the number of ``dtype`` elements that fit in 16 bytes; the key
    # cache stores head_size split as (head_size // x, x) around the block
    # axis (presumably to match the kernel's vectorised loads).
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.empty(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    key_cache.uniform_(-1e-3, 1e-3)
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.empty(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
    value_cache.uniform_(-1e-3, 1e-3)

    context_len_list = [
        random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)
    ]
    max_context_len = max(context_len_list)
    context_lens = torch.tensor(
        context_len_list, dtype=torch.int, device='cuda')

    # Every token gets a table long enough for the longest context; the
    # entries are arbitrary block numbers within the cache.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = [
        [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        for _ in range(num_tokens)
    ]
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
    )

    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
    )
    # NOTE(woosuk): Due to the difference in the data types the two
    # implementations use for attention softmax logits and accumulation,
    # there is a small difference in the final outputs.
    # We should use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5), (
        f'max abs diff: {(output - ref_output).abs().max().item()}')


@torch.inference_mode()
def run_multi_query_kv_attention(
    num_seqs: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
) -> None:
    """Check xformers block-diagonal causal attention against the reference.

    Draws ``num_seqs`` distinct random sequence lengths, packs random
    query/key/value rows for all of them on the GPU, runs
    ``xops.memory_efficient_attention_forward`` with a
    ``BlockDiagonalCausalMask`` and compares the result with
    ``ref_multi_query_kv_attention``.

    Args:
        num_seqs: Number of sequences packed into one batch.
        num_heads: Number of attention heads.
        head_size: Size of each head.
        dtype: dtype of the inputs.

    Raises:
        AssertionError: If the xformers output differs from the reference.
    """
    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size ** 0.5))
    qkv = torch.empty(
        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    qkv.uniform_(-1e-3, 1e-3)
    query, key, value = qkv.unbind(dim=1)

    attn_op = xops.fmha.cutlass.FwOp()
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    # The op is called with a leading dimension of size 1, which is
    # squeezed away again afterwards.
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
        op=attn_op,
    )
    output = output.squeeze(0)

    # Prefix sums of the lengths give each sequence's row range.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5), (
        f'max abs diff: {(output - ref_output).abs().max().item()}')


def test_single_query_cached_kv_attention() -> None:
    """Run the single-query cached-KV check over dtypes, block and head sizes."""
    # The run helper draws context lengths and block tables from Python's
    # ``random`` module, so that generator is seeded along with torch's;
    # otherwise the inputs would differ between runs.
    random.seed(TEST_SEED)
    torch.random.manual_seed(TEST_SEED)
    torch.cuda.manual_seed(TEST_SEED)
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for block_size in [8, 16, 32]:
            for head_size in [64, 80, 96, 128]:
                print(f'Testing single_query_cached_kv_attention with '
                      f'dtype={dtype}, block_size={block_size}, '
                      f'head_size={head_size}')
                run_single_query_cached_kv_attention(
                    num_tokens=37,
                    num_heads=3,
                    head_size=head_size,
                    block_size=block_size,
                    num_blocks=1024,
                    dtype=dtype,
                )


def test_multi_query_kv_attention() -> None:
    """Run the multi-query attention check over dtypes and head sizes."""
    # The run helper draws sequence lengths from Python's ``random`` module,
    # so that generator is seeded along with torch's; otherwise the inputs
    # would differ between runs.
    random.seed(TEST_SEED)
    torch.random.manual_seed(TEST_SEED)
    torch.cuda.manual_seed(TEST_SEED)
    for dtype in [torch.half, torch.bfloat16, torch.float]:
        for head_size in [64, 80, 96, 128]:
            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                  f'head_size={head_size}')
            run_multi_query_kv_attention(
                num_seqs=5,
                num_heads=3,
                head_size=head_size,
                dtype=dtype,
            )