import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import attention_ops
from vllm.utils import get_max_shared_memory_bytes

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# The maximum context length is bounded by the shared memory available to the
# kernel (one float32 element per token), which varies with the compute
# capability. 512 is subtracted as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 128  # Arbitrary values for testing

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
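    # Straightforward reference attention. Shapes, as implied by the einsum
    # subscripts below:
    #   query:      [num_query_tokens, num_heads, head_size]
    #   key, value: [num_kv_tokens, num_heads, head_size]
    #   attn_mask:  broadcastable to [num_heads, num_query_tokens, num_kv_tokens]
    # The output has the same shape as `query`.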
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

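            # Gather this token's key/value from the paged cache. The key cache
            # keeps the head dimension split across two trailing axes (restored
            # by the reshape below); the value cache is laid out as
            # [num_blocks, num_kv_heads, head_size, block_size].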
            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
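            # For the single query at position context_len - 1, the bias for
            # key position j is slope * (j - (context_len - 1)), which is
            # always non-positive.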
            position_ids = torch.arange(context_len, device="cuda").int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_single_query_cached_kv_attention(
    kv_cache_factory,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
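    # Map each query head to the KV head it reads from (identity for MHA,
    # many-to-one for GQA/MQA).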
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
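    # Make sure the kernel is exercised at the maximum supported context length.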
    context_lens[-1] = MAX_SEQ_LEN
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
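    # Each row maps a sequence's logical block index to a physical block id in
    # the cache; ids are drawn at random (and may repeat) since the cache
    # contents are random anyway.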
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV caches.
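    # kv_cache_factory is a test fixture; the third argument appears to request
    # a single layer's cache, hence the [0] indexing below.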
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size, dtype,
                                                seed)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
    )

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences in the two
    # implementations, there is a small numerical difference in the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
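        # Upper-triangular entries (future positions) are set to the most
        # negative finite value of the dtype, which softmax turns into ~0
        # probability, i.e. a causal mask.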
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype, device="cuda")

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    seq_lens[-1] = MAX_SEQ_LEN
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device="cuda")
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
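
    # All sequences are packed along the token dimension; BlockDiagonalCausalMask
    # applies a causal mask within each sequence of the packed batch.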
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)