import random
from typing import List, Optional, Tuple

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from vllm import attention_ops

MAX_SEQ_LEN = 8192
NUM_BLOCKS = 128  # Arbitrary values for testing

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_GEN_SEQS = [7]  # Arbitrary values for testing
NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
USE_ALIBI = [False, True]
SEEDS = [0]


def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
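    # query: [num_query_tokens, num_heads, head_size]
    # key/value: [num_kv_tokens, num_heads, head_size]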
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    if attn_mask is not None:
        attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out


def ref_single_query_cached_kv_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    num_queries_per_kv: int,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> None:
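    # Reference implementation that gathers K/V token by token from the paged
    # cache. The cache layouts assumed here (as implied by the indexing below):
    #   key_cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
    #   value_cache: [num_blocks, num_kv_heads, head_size, block_size]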
    num_query_heads = query.shape[1]
    num_kv_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    num_seqs = query.shape[0]

    block_tables = block_tables.cpu().tolist()
    context_lens = context_lens.cpu().tolist()
    for i in range(num_seqs):
        q = query[i].unsqueeze(0)
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys = []
        values = []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size

            k = key_cache[block_number, :, :, block_offset, :]
            k = k.reshape(num_kv_heads, head_size)
            keys.append(k)

            v = value_cache[block_number, :, :, block_offset]
            values.append(v)
        keys = torch.stack(keys, dim=0)
        values = torch.stack(values, dim=0)
        if num_queries_per_kv > 1:
            # Handle MQA and GQA
            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

        alibi_bias = None
        if alibi_slopes is not None:
            # Create the ALiBi bias used in the paged attention kernel.
            position_ids = torch.arange(context_len, device="cuda").int()
            alibi_bias = (position_ids - context_len + 1).float()
            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                1, 1, -1)

        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
        out = out.view(num_query_heads, head_size)
        output[i].copy_(out, non_blocking=True)


@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_single_query_cached_kv_attention(
    kv_cache_factory,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
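    # head_mapping maps each query head to its KV head (identity for MHA,
    # many-to-one for MQA/GQA).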
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")

    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

    # Create the KV caches.
    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
                                                num_kv_heads, head_size, dtype,
                                                seed)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Call the paged attention kernel.
    output = torch.empty_like(query)
    attention_ops.single_query_cached_kv_attention(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes,
    )

    # Run the reference implementation.
    ref_output = torch.empty_like(query)
    ref_single_query_cached_kv_attention(
        ref_output,
        query,
        num_queries_per_kv,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        scale,
        alibi_slopes,
    )

    # NOTE(woosuk): Due to the kernel-level differences between the two
    # implementations, there is a small numerical difference between the two
    # outputs. Thus, we use a relaxed tolerance for the test.
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def ref_multi_query_kv_attention(
    cu_seq_lens: List[int],
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype: torch.dtype,
) -> torch.Tensor:
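    # Naive reference: run causal attention over each sequence independently
    # and concatenate the outputs along the token dimension.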
    num_seqs = len(cu_seq_lens) - 1
    ref_outputs = []
    for i in range(num_seqs):
        start_idx = cu_seq_lens[i]
        end_idx = cu_seq_lens[i + 1]
        seq_len = end_idx - start_idx

        # Create attention mask.
        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(dtype).min
        attn_mask = attn_mask.to(dtype=dtype, device="cuda")

        ref_output = ref_masked_attention(
            query[start_idx:end_idx],
            key[start_idx:end_idx],
            value[start_idx:end_idx],
            scale,
            attn_mask=attn_mask,
        )
        ref_outputs.append(ref_output)
    ref_output = torch.cat(ref_outputs, dim=0)
    return ref_output


# TODO(woosuk): Add tests for USE_ALIBI=True.
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
    num_tokens = sum(seq_lens)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype,
                      device="cuda")
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
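
    # The xformers kernel is the implementation under test; it expects a batch
    # dimension, hence the unsqueeze(0)/squeeze(0) around the call.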
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    output = output.squeeze(0)

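    # Build cumulative sequence lengths ([0, len_0, len_0 + len_1, ...]) for
    # the reference implementation.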
    cu_seq_lens = [0]
    for seq_len in seq_lens:
        cu_seq_lens.append(cu_seq_lens[-1] + seq_len)
    ref_output = ref_multi_query_kv_attention(
        cu_seq_lens,
        query,
        key,
        value,
        scale,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)