bench_concat_mla_q.py 2.95 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import triton

# DeepSeek V3 dimensions
NOPE_DIM = 512  # per-head width of the "nope" slice (last dim of ql_nope)
ROPE_DIM = 64  # per-head width of the rope slice (last dim of q_pe)
NUM_HEADS = 128  # number of query heads

# Token counts swept on the x-axis: every power of two from 2**3 (8) to 2**13 (8192).
NUM_TOKENS = [1 << exponent for exponent in range(3, 14)]


def get_configs():
    """Return the num_tokens values used as the benchmark x-axis.

    This is the module-level ``NUM_TOKENS`` list itself (not a copy).
    """
    return NUM_TOKENS


def make_inputs(num_tokens, dtype, contiguous_nope=False):
    """Create the (ql_nope, q_pe) query inputs for one benchmark point.

    By default ql_nope matches the real code path. There it is a batched
    matmul output allocated as [NUM_HEADS, num_tokens, NOPE_DIM], then viewed
    as [num_tokens, NUM_HEADS, NOPE_DIM] via ``transpose(0, 1)``, which makes
    it a non-contiguous tensor.

    Args:
        num_tokens: Number of tokens (leading dimension of both outputs).
        dtype: torch dtype used for both tensors.
        contiguous_nope: If True, allocate ql_nope directly as a contiguous
            [num_tokens, NUM_HEADS, NOPE_DIM] tensor instead of the transposed
            view. Defaults to False, i.e. the real (transposed) code path.

    Returns:
        A tuple ``(ql_nope, q_pe)``. ql_nope has shape
        [num_tokens, NUM_HEADS, NOPE_DIM]. q_pe is a contiguous
        [num_tokens, NUM_HEADS, ROPE_DIM] tensor. Both are filled with
        ``torch.randn`` values on the "cuda" device.
    """
    if contiguous_nope:
        ql_nope = torch.randn(
            num_tokens, NUM_HEADS, NOPE_DIM, dtype=dtype, device="cuda"
        )
    else:
        # Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L]
        raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda")
        ql_nope = raw.transpose(0, 1)

    q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda")
    return ql_nope, q_pe


# ---- Non-contiguous nope benchmark (real code path) ----
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=get_configs(),
        line_arg="provider",
        line_vals=["torch_cat", "concat_mla_q"],
        line_names=["torch.cat", "concat_mla_q (v8)"],
        styles=[("blue", "--"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="concat_mla_q-transposed",
        args={},
    )
)
def bench_transposed(num_tokens, provider):
    """Time one (num_tokens, provider) point of the transposed-nope sweep.

    Args:
        num_tokens: Number of tokens passed to ``make_inputs``.
        provider: "torch_cat" times ``torch.cat((ql_nope, q_pe), dim=-1)``.
            Any other value (in practice "concat_mla_q") times
            ``ops.concat_mla_q`` writing into a preallocated output buffer.

    Returns:
        ``(median, p20, p80)`` latency in microseconds. This is the
        ``(y, y_min, y_max)`` order that triton's ``perf_report`` unpacks.
    """
    dtype = torch.bfloat16
    ql_nope, q_pe = make_inputs(num_tokens, dtype)

    # Destination buffer for the custom kernel; torch.cat allocates its own.
    q_out = torch.empty(
        num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda"
    )

    # Median first, then the 20th / 80th percentiles used as the low / high band.
    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch_cat":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500
        )
    else:
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), quantiles=quantiles, rep=500
        )

    # ms -> us. For a latency metric the lower percentile is the minimum, so it
    # must come before the upper percentile in the (y, y_min, y_max) tuple.
    return ms * 1000, min_ms * 1000, max_ms * 1000  # us


if __name__ == "__main__":
    # The only CLI option is an optional location for the result tables.
    cli = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat")
    cli.add_argument(
        "--save-path", type=str, default=None, help="Path to save benchmark results"
    )
    args = cli.parse_args()

    divider = "=" * 70
    head_width = NOPE_DIM + ROPE_DIM  # elements per head in the concatenated output

    # Header summarising the problem sizes being swept.
    print("\n" + divider)
    print("CONCAT MLA Q KERNEL BENCHMARKS")
    print(divider)
    print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}")
    print(f"Per-head output: {head_width} bf16 = {head_width * 2} bytes")
    print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}")
    print(divider)

    print("\n--- Non-contiguous nope inputs (transposed BMM output) ---")
    bench_transposed.run(print_data=True, save_path=args.save_path)

    # Footer.
    print("\n" + divider)
    print("Benchmarking complete!")
    print(divider)