benchmark_attention.py 4.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import functools
import random
import time
from typing import List

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch

from cacheflow import attention_ops


def benchmark(name, f, num_warmup=10, num_iters=100):
    """Time repeated calls to ``f`` and print the mean latency per call.

    ``f`` is first called ``num_warmup`` times without being timed, then
    ``num_iters`` times inside the timed region. ``torch.cuda.synchronize()``
    is called before the clock starts and again before it stops.

    Args:
        name: Label printed in front of the measured latency.
        f: Zero-argument callable to measure.
        num_warmup: Number of untimed calls made before measuring.
        num_iters: Number of timed calls; must be positive because the
            elapsed time is divided by it.

    Raises:
        ValueError: If ``num_iters`` is not positive.
    """
    # Fail fast instead of doing the warmup and then dividing by zero (or
    # reporting a negative latency for a negative iteration count).
    if num_iters <= 0:
        raise ValueError(f'num_iters must be positive, got {num_iters}')

    for _ in range(num_warmup):
        f()
    # Make sure pending warmup work has finished before the clock starts.
    torch.cuda.synchronize()

    # perf_counter() is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(num_iters):
        f()
    # Wait for the timed work to complete before reading the clock again.
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')


def _cumulative_lens(lens):
    """Return the running totals of ``lens`` with a leading 0."""
    totals = [0]
    for length in lens:
        totals.append(totals[-1] + length)
    return totals


@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
    query_lens: List[int],
    context_lens: List[int],
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
    """Benchmark the cached-KV multi-query attention op against Flash attention.

    Prints the configuration, then times
    ``attention_ops.multi_query_cached_kv_attention`` on randomly filled
    key/value caches addressed through random block tables, and finally times
    ``_flash_attn_forward`` on contiguous random key/value tensors with the
    same sequence lengths as a reference point. All tensors are created on
    the ``'cuda'`` device.

    Args:
        query_lens: Number of query tokens for each sequence.
        context_lens: Number of key/value tokens for each sequence; must have
            the same length as ``query_lens``.
        num_heads: Number of attention heads.
        head_size: Size of each head; must be a multiple of the number of
            ``dtype`` elements that fit in 16 bytes, because the key cache
            shape splits the head dimension into chunks of that many elements.
        block_size: Number of token slots per cache block.
        num_blocks: Number of blocks allocated in each cache.
        dtype: Element type of all floating-point tensors.

    Raises:
        ValueError: If the length lists are empty or differ in length, or if
            ``head_size`` is not divisible by the key cache chunk width.
    """
    # Validate before touching the GPU: inconsistent inputs would otherwise
    # build tensors that disagree about the number of sequences.
    if not query_lens:
        raise ValueError('query_lens must not be empty')
    if len(query_lens) != len(context_lens):
        raise ValueError(
            f'query_lens and context_lens must have the same length, got '
            f'{len(query_lens)} and {len(context_lens)}')
    # Number of elements of `dtype` that fit in 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    if head_size % x != 0:
        # `head_size // x` below would silently drop the remainder.
        raise ValueError(
            f'head_size ({head_size}) must be a multiple of {x} for {dtype}')

    print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
          f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
          f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
    # Create query tensor.
    num_queries = len(query_lens)
    cu_query_lens = _cumulative_lens(query_lens)
    num_total_tokens = cu_query_lens[-1]
    qkv = torch.randn(
        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    # The query is a view into the packed qkv tensor.
    query, _, _ = qkv.unbind(dim=1)

    # Create key and value cache.
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    # Create block tables: one row of random block indices per sequence,
    # wide enough to cover the longest context.
    max_context_len = max(context_lens)
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    # Create input and output data structures.
    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')

    # Run our implementation.
    def run_ours():
        attention_ops.multi_query_cached_kv_attention(
            cu_query_lens,
            output,
            query,
            key_cache,
            value_cache,
            scale,
            block_tables,
            context_len_tensor,
            block_size,
            max_context_len,
        )
    benchmark('Ours', run_ours)

    # Upper bound: Flash attention.
    # Because Flash attention cannot read our own cache,
    # we make key and value tensors contiguous.
    num_kv_tokens = sum(context_lens)
    cu_context_lens = torch.tensor(
        _cumulative_lens(context_lens), dtype=torch.int, device='cuda')
    qkv = torch.randn(
        num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    _, key, value = qkv.unbind(dim=1)
    ref_output = torch.empty_like(output)
    # Computed once here so the timed closure does not redo it on every call.
    max_query_len = max(query_lens)

    # Run Flash attention.
    def run_flash_attn():
        _flash_attn_forward(
            query,
            key,
            value,
            ref_output,
            cu_query_lens,
            cu_context_lens,
            max_query_len,
            max_context_len,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=True,
            return_softmax=False,
        )
    benchmark('Flash attention', run_flash_attn)


if __name__ == '__main__':
    # Cache layout and precision shared by every workload below.
    BLOCK_SIZE = 8
    NUM_BLOCKS = 1024
    DTYPE = torch.half

    # LLaMA-13B and OPT-13B
    NUM_HEADS = 40
    HEAD_SIZE = 128

    # Bind the fixed settings once; each workload only supplies the lengths.
    run_benchmark = functools.partial(
        benchmark_multi_query_cached_kv_attention,
        num_heads=NUM_HEADS,
        head_size=HEAD_SIZE,
        block_size=BLOCK_SIZE,
        num_blocks=NUM_BLOCKS,
        dtype=DTYPE,
    )

    # (query_lens, context_lens) pairs, benchmarked in the order listed.
    WORKLOADS = [
        ([64] * 1, [64] * 1),
        ([128] * 1, [128] * 1),
        ([64] * 8, [64] * 8),
        ([128] * 8, [128] * 8),
        ([64, 32, 16], [128, 256, 64]),
        ([1024], [1024]),
    ]
    for query_lens, context_lens in WORKLOADS:
        run_benchmark(query_lens=query_lens, context_lens=context_lens)