"""Attention kernel tests.

Checks vLLM's paged-attention decode kernels (``paged_attention_v1`` and
``paged_attention_v2``) and xformers' block-diagonal causal prefill attention
against naive pure-PyTorch reference implementations defined in this module.
"""
import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm._C import ops
from vllm.utils import get_max_shared_memory_bytes

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
# NOTE(review): presumably sized so that one float32 value per context token
# fits in the device's shared memory; confirm against the kernel.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 40000  # Arbitrary values for testing
# Tokens per partition for paged_attention_v2; must be a multiple of the
# block size (asserted in test_paged_attention).
PARTITION_SIZE = 512

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
# (num_query_heads, num_kv_heads); (64, 8) exercises grouped-query attention.
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Naive scaled dot-product attention used as a numerical reference.

    Args:
        query: ``(num_queries, num_heads, head_size)`` tensor.
        key: ``(num_keys, num_heads, head_size)`` tensor.
        value: ``(num_keys, num_heads, head_size)`` tensor.
        scale: Multiplier applied to the ``query . key`` logits.
        attn_mask: Optional additive bias, broadcastable to
            ``(num_heads, num_queries, num_keys)``; it is added in float32
            just before the softmax.

    Returns:
        ``(num_queries, num_heads, head_size)`` tensor in ``value.dtype``.
    """
    # The QK product itself runs in the input dtype; only its result is
    # upcast, so the softmax is evaluated in float32.
    raw_scores = torch.einsum("qhd,khd->hqk", query, key)
    scores = scale * raw_scores.float()
    if attn_mask is not None:
        scores = scores + attn_mask.float()
    probs = torch.softmax(scores, dim=-1)
    # Cast back so the weighted sum over values runs in the values' dtype.
    probs = probs.to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode attention over a paged KV cache, written into ``output``.

    For each sequence ``i``, the single query ``query[i]`` attends to the first
    ``context_lens[i]`` cached tokens. Token ``j`` of that sequence is stored in
    block ``block_tables[i][j // block_size]`` at offset ``j % block_size``.

    Args:
        output: ``(num_seqs, num_query_heads, head_size)`` destination; row
            ``i`` is overwritten in place.
        query: ``(num_seqs, num_query_heads, head_size)`` queries.
        num_queries_per_kv: Number of query heads sharing one KV head
            (greater than 1 for MQA/GQA).
        key_cache: 5-D cache indexed as ``[block, kv_head, :, offset, :]``;
            the two unnamed dims flatten to ``head_size``.
        value_cache: Cache of shape
            ``(num_blocks, num_kv_heads, head_size, block_size)``.
        block_tables: Per-sequence physical block ids.
        context_lens: Per-sequence number of cached tokens.
        scale: Softmax scale applied to the attention logits.
        alibi_slopes: Optional per-query-head slopes; when given, a bias of
            ``slope * (j - context_len + 1)`` is added for token ``j``.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]
    cache_device = key_cache.device

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        context_len = int(context_lens[i])

        # Gather all cached tokens of this sequence with one advanced-indexing
        # step instead of a per-token Python loop (pure copy, same values).
        token_ids = torch.arange(context_len, device=cache_device)
        block_table = torch.tensor(block_tables[i], device=cache_device)
        block_ids = block_table[token_ids // block_size]
        offsets = token_ids % block_size
        # The two index tensors are separated by slices, so the gathered token
        # dimension moves to the front: (context_len, num_kv_heads, ...).
        keys = key_cache[block_ids, :, :, offsets, :].reshape(
            context_len, num_kv_heads, head_size)
        values = value_cache[block_ids, :, :, offsets]

        if num_queries_per_kv > 1:
            # Handle MQA and GQA: expand each KV head to its query heads.
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel, on the
            # slopes' device rather than a hard-coded "cuda".
            position_ids = torch.arange(context_len,
                                        device=alibi_slopes.device).int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Compare ``paged_attention_v1``/``v2`` with the pure-PyTorch reference.

    ``num_heads`` is ``(num_query_heads, num_kv_heads)``. ``kv_cache_factory``
    is a fixture defined outside this file; it returns lists of key and value
    caches, of which the first of each is used (NOTE(review): the literal ``1``
    passed to it presumably is the number of layers — confirm in conftest).
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    # Query head h reads KV head h // num_queries_per_kv.
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    # Random context lengths; the last is pinned to the upper bound
    # MAX_SEQ_LEN so that length is always exercised.
    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables: random (possibly repeated) block ids.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = [[
        random.randint(0, NUM_BLOCKS - 1)
        for _ in range(max_num_blocks_per_seq)
    ] for _ in range(num_seqs)]
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size, dtype,
                                                seed)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    elif version == "v2":
        assert PARTITION_SIZE % block_size == 0
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        # Scratch buffers for the v2 kernel. The shapes reuse the test's own
        # sizes (output has the query's shape) instead of unpacking
        # output.shape, which used to rebind the `num_heads` tuple parameter.
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over packed variable-length sequences.

    Args:
        cu_seq_lens: Cumulative boundaries ``[0, l0, l0 + l1, ...]``; sequence
            ``i`` occupies rows ``cu_seq_lens[i]:cu_seq_lens[i + 1]``.
        query: Packed ``(num_tokens, num_heads, head_size)`` queries.
        key: Packed ``(num_tokens, num_heads, head_size)`` keys.
        value: Packed ``(num_tokens, num_heads, head_size)`` values.
        scale: Softmax scale applied to the attention logits.
        dtype: dtype of the causal mask; masked positions receive
            ``torch.finfo(dtype).min``.

    Returns:
        The per-sequence outputs concatenated along the token dimension.
    """
    ref_outputs = []
    for start_idx, end_idx in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq_len = end_idx - start_idx

        # Causal mask: entry (q, k) with k > q is finfo(dtype).min, else 0.
        # Built directly on the inputs' device instead of a hard-coded "cuda".
        attn_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=dtype, device=query.device),
            diagonal=1) * torch.finfo(dtype).min

        ref_outputs.append(
            ref_masked_attention(
                query[start_idx:end_idx],
                key[start_idx:end_idx],
                value[start_idx:end_idx],
                scale,
                attn_mask=attn_mask,
            ))
    return torch.cat(ref_outputs, dim=0)


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Compare xformers block-diagonal causal attention with the reference.

    ``num_seqs`` prompts of distinct random lengths are packed along one token
    dimension; xformers' output must match naive per-sequence causal
    attention. ``num_heads`` is ``(num_query_heads, num_kv_heads)``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    seq_len_cap = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, seq_len_cap), num_seqs)
    total_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # One packed tensor holding query, key and value heads side by side.
    packed_qkv = torch.empty(total_tokens,
                             num_query_heads + 2 * num_kv_heads,
                             head_size,
                             dtype=dtype,
                             device="cuda")
    packed_qkv.uniform_(-scale, scale)
    query, key, value = packed_qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    group_size = num_query_heads // num_kv_heads
    if group_size > 1:
        # Handle MQA and GQA
        key = key.repeat_interleave(group_size, dim=1)
        value = value.repeat_interleave(group_size, dim=1)

    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    # xformers expects a leading batch dimension; the packed tokens form a
    # single batch entry separated by the block-diagonal mask.
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    ).squeeze(0)

    boundaries = [0]
    for length in seq_lens:
        boundaries.append(boundaries[-1] + length)

    ref_output = ref_multi_query_kv_attention(boundaries, query, key, value,
                                              scale, dtype)
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)