"""Benchmark the sgl_kernel rotary-embedding forward pass.

Times ``FlashInferRotaryEmbedding.forward_cuda`` over a grid of
(batch_size, seq_len) shapes, with and without the fused KV-cache write,
and reports kernel time in microseconds.
"""
import itertools

import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
    FlashInferRotaryEmbedding,
    MHATokenToKVPool,
    RotaryEmbedding,
    create_inputs,
)

from sglang.srt.bench_utils import bench_kineto

# Benchmark grid of (batch_size, seq_len, save_kv_cache) tuples. Every
# (batch_size, seq_len) shape is paired with save_kv_cache=False and then
# save_kv_cache=True, in shape-major order.
configs = [
    (batch_size, seq_len, save_kv_cache)
    for (batch_size, seq_len), save_kv_cache in itertools.product(
        (
            (1, 1),
            (32, 1),
            (128, 1),
            (512, 1),
            (2, 512),
            (4, 4096),
        ),
        (False, True),
    )
]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "save_kv_cache"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang"],
        line_names=["SGL Kernel"],
        styles=[("green", "-")],
        ylabel="us",
        plot_name="bench_rotary_embedding",
        args={},
    )
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
    """Time one rotary-embedding forward call for a single benchmark config.

    Builds a ``FlashInferRotaryEmbedding`` with a fixed model shape
    (32 query heads, 8 KV heads, head size 64, bfloat16) on CUDA, then
    measures the ``BatchQKApplyRotaryPosIds`` kernel with ``bench_kineto``.

    Args:
        batch_size: Batch size forwarded to ``create_inputs``.
        seq_len: Sequence length forwarded to ``create_inputs``.
        save_kv_cache: When True, a ``FusedSetKVBufferArg`` targeting layer 0
            of an ``MHATokenToKVPool`` is passed, so the fused KV-buffer
            write path is benchmarked as well.
        provider: Benchmark line name; only ``"sglang"`` is supported.

    Returns:
        Measured kernel time in microseconds.

    Raises:
        ValueError: If ``provider`` is not ``"sglang"``.
    """
    if provider != "sglang":
        raise ValueError(f"Unsupported provider: {provider!r}")

    device = torch.device("cuda")

    num_q_heads = 32
    num_kv_heads = 8
    head_size = 64
    dtype = torch.bfloat16

    rope_flashinfer = FlashInferRotaryEmbedding(
        head_size=head_size,
        rotary_dim=head_size,  # rotate the full head
        max_position_embeddings=4096,
        base=8000,
        is_neox_style=True,
        dtype=dtype,
    ).to(device)
    pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )

    # Work on copies so `inputs` is left untouched; presumably forward_cuda
    # rotates query/key in place (hence the clones) — TODO confirm.
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

    # Build the fused KV-cache argument once, outside the timed closure, so
    # each benchmarked call does not re-create the dataclass and buffer views.
    fused_set_kv_buffer_arg = None
    if save_kv_cache:
        kv_width = num_kv_heads * head_size
        fused_set_kv_buffer_arg = FusedSetKVBufferArg(
            value=inputs["value"],
            # Layer-0 K/V buffers viewed as 2-D with kv_width columns.
            k_buffer=pool_flashinfer.k_buffer[0].view(-1, kv_width),
            v_buffer=pool_flashinfer.v_buffer[0].view(-1, kv_width),
            k_scale=None,
            v_scale=None,
            cache_loc=inputs["out_cache_loc"],
        )

    pos_ids = inputs["pos_ids"]

    def bench_fn():
        return rope_flashinfer.forward_cuda(
            pos_ids,
            query_flashinfer,
            key_flashinfer,
            fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
        )

    time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
    return time_s * 1e6


if __name__ == "__main__":
    # Sweep every entry of `configs` through `benchmark`; print_data=True
    # prints the collected timings.
    benchmark.run(print_data=True)