test_attention.py 10.6 KB
Newer Older
1
import random
2
from typing import List, Optional, Tuple
3

4
import pytest
5
import torch
6
7
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
8

9
from vllm._C import ops
10
from vllm.utils import get_max_shared_memory_bytes
11

12
13
14
15
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
16
NUM_BLOCKS = 40000  # Arbitrary values for testing
17
PARTITION_SIZE = 512
18
19
20

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
21
NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
22
23
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
24
BLOCK_SIZES = [16, 32]
25
USE_ALIBI = [False, True]
26
SEEDS = [0]
27

28
29
30
31
32
33
34
35

def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Plain-PyTorch reference attention.

    Args:
        query: [num_query_tokens, num_heads, head_size].
        key: [num_kv_tokens, num_heads, head_size].
        value: [num_kv_tokens, num_heads, head_size].
        scale: Softmax scaling factor applied to the raw logits.
        attn_mask: Optional additive bias of shape broadcastable to
            [num_heads, num_query_tokens, num_kv_tokens].

    Returns:
        Attention output of shape [num_query_tokens, num_heads, head_size],
        in ``value``'s dtype.
    """
    # Logits are accumulated in float32 for numerical stability.
    logits = torch.einsum("qhd,khd->hqk", query, key).float() * scale
    if attn_mask is not None:
        logits = logits + attn_mask.float()
    probs = torch.softmax(logits, dim=-1).to(value.dtype)
    return torch.einsum("hqk,khd->qhd", probs, value)


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    """Reference decode-phase attention over a paged KV cache.

    For each sequence, gathers its keys/values token-by-token out of the
    block-structured caches using its block table, runs
    ``ref_masked_attention`` with a single query token, and writes the
    result into ``output`` in place.

    Args:
        output: [num_seqs, num_query_heads, head_size]; written in place.
        query: [num_seqs, num_query_heads, head_size] — one token per seq.
        num_queries_per_kv: Query heads per KV head (>1 means MQA/GQA).
        key_cache: 5-D paged key cache; indexed as
            [block_number, :, :, block_offset, :] and reshaped to
            (num_kv_heads, head_size) per token — presumably the kernel's
            (num_blocks, num_heads, head_size/x, block_size, x) layout.
        value_cache: [num_blocks, num_kv_heads, head_size, block_size].
        block_tables: [num_seqs, max_num_blocks_per_seq] int block ids.
        context_lens: [num_seqs] int context lengths.
        scale: Softmax scale.
        alibi_slopes: Optional per-query-head ALiBi slopes, or None.
    """
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    # Work with plain Python lists for indexing below.
    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        # Gather this sequence's keys/values, one token position at a time.
        keys = []
        values = []
        for j in range(context_len):
            # Token j lives in block j // block_size at offset j % block_size.
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            # Distances are 0 for the last token, negative going backwards.
            position_ids = torch.arange(context_len, device="cuda").int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


100
@pytest.mark.parametrize("version", ["v1", "v2"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_paged_attention(
    kv_cache_factory,
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Compare the paged attention CUDA kernels (v1/v2) against the
    pure-PyTorch reference ``ref_single_query_cached_kv_attention``.

    Builds random queries, block tables, and KV caches (via the
    ``kv_cache_factory`` fixture), runs the requested kernel version, then
    checks the output against the reference within a relaxed tolerance.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    # Random context lengths; pin one sequence at MAX_SEQ_LEN so the longest
    # supported context is always exercised.
    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV caches (one layer).
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size, dtype,
                                                seed)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    if version == "v1":
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    elif version == "v2":
        # v2 splits each context into fixed-size partitions and reduces the
        # per-partition partial results, so it needs scratch buffers.
        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
                          PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Reference causal attention over packed variable-length sequences.

    Args:
        cu_seq_lens: Cumulative sequence lengths, e.g. [0, len0, len0+len1, ...].
        query/key/value: [total_num_tokens, num_heads, head_size], with all
            sequences concatenated along dim 0.
        scale: Softmax scale.
        dtype: Dtype used to build the causal mask.

    Returns:
        [total_num_tokens, num_heads, head_size] attention output.
    """
    per_seq_outputs = []
    # Each (start, end) pair delimits one sequence in the packed tensors.
    for start, end in zip(cu_seq_lens[:-1], cu_seq_lens[1:]):
        seq_len = end - start

        # Causal mask: strictly upper-triangular positions get the most
        # negative representable value so softmax drives them to zero.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                 diagonal=1)
        causal_mask = causal_mask * torch.finfo(dtype).min
        causal_mask = causal_mask.to(dtype=dtype, device="cuda")

        per_seq_outputs.append(
            ref_masked_attention(
                query[start:end],
                key[start:end],
                value[start:end],
                scale,
                attn_mask=causal_mask,
            ))
    return torch.cat(per_seq_outputs, dim=0)


266
# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    """Compare xformers' memory-efficient attention (the prefill path)
    against the pure-PyTorch reference ``ref_multi_query_kv_attention``
    on packed variable-length sequences with a block-diagonal causal mask.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
    # As the xformers library is already tested with its own tests, we can use
    # a smaller MAX_SEQ_LEN here.
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # One fused QKV tensor; split into Q/K/V along the head dimension below.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device="cuda")
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    # Block-diagonal causal bias keeps attention within each sequence.
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    # Build cumulative sequence boundaries for the reference implementation.
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    # Relaxed tolerance: kernel and reference differ numerically.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)