# test_prefix_prefill.py
import random
import time

4
import pytest
5
6
7
8
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

9
10
from vllm.attention.ops.prefix_prefill import context_attention_fwd

11
12
NUM_HEADS = [64]
# Query heads sharing one KV head; the test derives
# num_kv_heads = num_heads // num_queries_per_kv (1 -> MHA, 64 -> a single KV head).
NUM_QUERIES_PER_KV = [1, 8, 64]
# The test splits head_size into (head_size // 8, 8) for the K-cache layout,
# so every entry must be a multiple of 8.
HEAD_SIZES = [128]
DTYPES = [torch.float16]
# Exercise at most two GPUs. When no GPU is visible the list is empty, so
# pytest skips the parametrized test instead of failing on a "cuda:N" device
# that does not exist.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))
]


def _scatter_context_into_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    block_table: list,
    seq_starts: list,
    ctx_lens: list,
    block_size: int,
) -> None:
    """Copy each request's cached-context keys/values into its paged blocks.

    ``key``/``value`` hold all requests' tokens packed back to back. Request
    ``i`` starts at ``seq_starts[i]``, and its first ``ctx_lens[i]`` tokens are
    the cached context. Context token ``t`` of request ``i`` goes to slot
    ``t % block_size`` of physical block ``block_table[i][t // block_size]``.
    ``k_cache``/``v_cache`` are expected in the
    ``[num_blocks, block_size, num_kv_heads, head_size]`` layout. They are
    written in place.
    """
    # Flat [num_blocks * block_size, ...] views make a block's slots one
    # contiguous slice. view() (not reshape) guarantees writes hit the cache.
    k_slots = k_cache.view(-1, *k_cache.shape[2:])
    v_slots = v_cache.view(-1, *v_cache.shape[2:])
    for blocks, seq_start, ctx_len in zip(block_table, seq_starts, ctx_lens):
        for block_idx, offset in enumerate(range(0, ctx_len, block_size)):
            # The last block of a request may be only partially filled.
            n = min(block_size, ctx_len - offset)
            dst = blocks[block_idx] * block_size
            src = seq_start + offset
            k_slots[dst:dst + n].copy_(key[src:src + n])
            v_slots[dst:dst + n].copy_(value[src:src + n])


def _run_and_time(label: str, fn):
    """Call ``fn`` once as a warm-up, then again while timing it.

    Prints ``"<label> Time: X ms"`` for the second call and returns its result.
    The CUDA device is synchronized around the timed call, so the measurement
    covers the queued GPU work rather than just the kernel launch.
    """
    fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"{label} Time: {elapsed_ms:.2f} ms")
    return result


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    device: str,
) -> None:
    """Compare the Triton prefix-prefill kernel against an xformers reference.

    Each of ``BS`` requests has a random-length cached context, stored in a
    paged KV cache via ``block_table``, plus a random number of new query
    tokens. ``context_attention_fwd`` reads the new tokens' K/V from ``k``/``v``
    and the context from the paged caches.

    The reference runs xformers' memory-efficient attention over the full
    contiguous key/value sequences with a
    ``BlockDiagonalCausalFromBottomRightMask``. Both paths are warmed up,
    timed and printed, and the outputs must match to ``atol=1e-6``.
    """
    random.seed(0)
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    torch.set_default_device(device)

    # Need this, otherwise when we capture the graph the process
    # for GPU 1 would run on both GPU0 and GPU1 and things would hang
    #
    # see also similar issue: https://github.com/Dao-AILab/flash-attention/issues/523
    torch.cuda.set_device(device)

    MAX_SEQ_LEN = 1024
    MAX_CTX_LEN = 1024
    BS = 10
    cache_size = 640
    block_size = 32
    max_block_per_request = 64
    subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
    ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
    seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
    num_kv_heads = num_heads // num_queries_per_kv

    num_tokens = sum(subquery_lens)
    query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
    query.uniform_(-1e-3, 1e-3)
    output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)

    # Full (context + new tokens) keys/values of every request, packed back
    # to back. This is what the xformers reference attends over.
    kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype)
    kv.uniform_(-1e-3, 1e-3)
    key, value = kv.unbind(dim=1)

    k_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    v_cache = torch.zeros(cache_size,
                          block_size,
                          num_kv_heads,
                          head_size,
                          dtype=dtype)
    # K/V of the new (query) tokens only, packed like `query`.
    k = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)
    v = torch.zeros(num_tokens, num_kv_heads, head_size, dtype=dtype)

    # Give each request a random, non-overlapping set of physical blocks.
    values = torch.arange(0, cache_size, dtype=torch.long)
    values = values[torch.randperm(cache_size)]
    block_table = values[:BS * max_block_per_request].view(
        BS, max_block_per_request)

    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
    b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
    b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
                                            dtype=torch.long),
                               dim=0)
    max_input_len = MAX_SEQ_LEN

    # Host-side offsets, so the copy loops below index with plain ints
    # instead of 0-d device tensors (each of which forces a device sync).
    query_starts = b_start_loc.tolist()
    seq_starts = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
                                           dtype=torch.long),
                              dim=0).tolist()

    # New tokens are the ones that follow each request's cached context.
    for q_start, q_len, seq_start, ctx_len in zip(query_starts, subquery_lens,
                                                  seq_starts, ctx_lens):
        new_start = seq_start + ctx_len
        k[q_start:q_start + q_len].copy_(key[new_start:new_start + q_len])
        v[q_start:q_start + q_len].copy_(value[new_start:new_start + q_len])

    # copy kv to cache
    _scatter_context_into_cache(key, value, k_cache, v_cache,
                                block_table.tolist(), seq_starts, ctx_lens,
                                block_size)

    # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
    # NOTE(review): the innermost 8 matches 16 bytes of float16. Other dtypes
    # may need a different split; confirm against the kernel before
    # extending DTYPES.
    k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
                           8).permute(0, 2, 3, 1, 4).contiguous()
    # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
    # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
    v_cache = v_cache.view(-1, block_size, num_kv_heads,
                           head_size).permute(0, 2, 3, 1).contiguous()

    # The first (warm-up) call triggers Triton compilation so that only the
    # second call is timed.
    _run_and_time(
        "triton", lambda: context_attention_fwd(
            query, k, v, output, k_cache, v_cache, block_table, b_start_loc,
            b_seq_len, b_ctx_len, max_input_len))

    scale = float(1.0 / (head_size**0.5))
    attn_op = xops.fmha.cutlass.FwOp()

    if num_kv_heads != num_heads:
        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        #
        # see also: vllm/model_executor/layers/attention.py
        query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
                           query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
                                        num_queries_per_kv, key.shape[-1])
        value = value[:, :,
                      None, :].expand(value.shape[0], num_kv_heads,
                                      num_queries_per_kv, value.shape[-1])
    # xformers expects a leading batch dimension; all requests are packed
    # into one "batch" and separated by the block-diagonal mask.
    query = query.unsqueeze(0)
    key = key.unsqueeze(0)
    value = value.unsqueeze(0)

    attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
        subquery_lens, seq_lens)
    output_ref = _run_and_time(
        "xformers", lambda: xops.memory_efficient_attention_forward(
            query,
            key,
            value,
            attn_bias=attn_bias,
            p=0.0,
            scale=scale,
            op=attn_op,
        ))

    output_ref = output_ref.reshape(output.shape)
    assert torch.allclose(output_ref, output, atol=1e-6, rtol=0), (
        f"max abs diff: {(output_ref - output).abs().max().item()}")