benchmark_attnmask.py 14.2 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
#!/usr/bin/env python
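# Benchmark: FlashAttention latency and speed ratio without vs. with attn_mask, across sizes.
#
# Run with no arguments to print, for each head dimension, four tables in this order:
# fwd causal=True, bwd causal=True, fwd causal=False, bwd causal=False
#   python benchmarks/benchmark_attnmask.py
# Forward only:      python benchmarks/benchmark_attnmask.py --no-backward
# causal=True only:  python benchmarks/benchmark_attnmask.py --no-both-causal
# causal=False only: python benchmarks/benchmark_attnmask.py --no-both-causal --no-causal
# Detailed comparison (aligned text instead of tab-separated tables): python benchmarks/benchmark_attnmask.py --no-table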

import argparse
import sys

# To use the same sizes as common benchmark tables, pass e.g.: --sizes "1,1024 1,2048 1,4096 1,8192 1,16384 1,32768 8,1024 ..."
import math
import torch

from flash_attn import flash_attn_func, flash_attn_with_mask_func
from flash_attn.utils.benchmark import benchmark_forward, benchmark_fwd_bwd


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    """Estimate the FLOP count of one attention pass.

    The forward count is ``4 * batch * seqlen**2 * nheads * headdim``,
    floor-divided by 2 when ``causal`` is truthy. ``mode="bwd"`` returns 2.5x
    the forward count and ``mode="fwd_bwd"`` returns 3.5x (both as floats).
    Per the original note, the formula follows benchmark_flash_attention.py /
    fa_bwd_benchmark.py.

    Raises:
        AssertionError: if ``mode`` is not "fwd", "bwd" or "fwd_bwd".
    """
    assert mode in ("fwd", "bwd", "fwd_bwd")
    divisor = 2 if causal else 1
    forward_flops = 4 * batch * seqlen**2 * nheads * headdim // divisor
    if mode == "fwd":
        return forward_flops
    multiplier = 2.5 if mode == "bwd" else 3.5
    return multiplier * forward_flops


def efficiency(flop, time_sec):
    """Convert a FLOP count and a duration in seconds to TFLOPs/s.

    Returns ``flop / time_sec / 1e12``; a NaN, zero or negative ``time_sec``
    yields ``0.0`` instead of dividing.
    """
    if math.isnan(time_sec) or time_sec <= 0:
        return 0.0
    return flop / time_sec / 10**12


def attn_mask_bytes(batch, nheads_q, seqlen):
    """Return the element count of a (batch, nheads_q, seqlen, seqlen) mask.

    The result is used as a byte count on the basis of one byte per bool element.
    """
    elements_per_column = batch * nheads_q * seqlen
    return elements_per_column * seqlen


def _time_forward_ms(fn, *args, repeats=30, **kwargs):
    """Return the mean forward time of ``fn(*args, **kwargs)`` scaled by 1000.

    Delegates to ``benchmark_forward`` with ``verbose=False`` and reads the
    ``mean`` of the second element of its result. The x1000 scaling gives
    milliseconds on the assumption that ``mean`` is in seconds (TODO confirm
    against ``flash_attn.utils.benchmark``).
    """
    _timer, measurement = benchmark_forward(
        fn, *args, repeats=repeats, verbose=False, **kwargs
    )
    mean_ms = measurement.mean * 1000.0
    return mean_ms


def _time_fwd_bwd_ms(fn, *args, repeats=30, **kwargs):
    """Return ``(forward_mean, backward_mean)`` of ``fn``, each scaled by 1000.

    Delegates to ``benchmark_fwd_bwd`` with ``verbose=False``; its result is
    unpacked as two ``(_, measurement)`` pairs. The x1000 scaling gives
    milliseconds on the assumption that ``mean`` is in seconds (TODO confirm
    against ``flash_attn.utils.benchmark``).
    """
    fwd_result, bwd_result = benchmark_fwd_bwd(
        fn, *args, repeats=repeats, verbose=False, **kwargs
    )
    _, fwd_measurement = fwd_result
    _, bwd_measurement = bwd_result
    return fwd_measurement.mean * 1000.0, bwd_measurement.mean * 1000.0


# Head dimensions for which the backward comparison is run.
_BWD_HEADDIMS = (64, 128)

# Leading (key) columns shared by every table header.
_TABLE_KEY_COLUMNS = "batch_size\tseqlen\tseqlen\tnheads_q\tnum_heads_kv\tcausal\tdim\tdtype"


def _build_parser():
    """Build the command-line parser for the benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark: 无 attnmask vs 有 attnmask 延时与速度比。默认直接打表。")
    parser.add_argument("--table", action="store_true", default=True, help="打印表格(默认开启)")
    parser.add_argument("--no-table", action="store_false", dest="table", help="不打印表格,改为详细对比格式")
    parser.add_argument("--batch", type=int, nargs="+", default=[128], help="batch sizes(未指定 --sizes 时)")
    parser.add_argument("--seqlen", type=int, nargs="+", default=[512, 1024, 1280, 1536, 2048], help="sequence lengths(未指定 --sizes 时)")
    parser.add_argument("--sizes", type=str, default=None, help="(batch,seqlen) 对,空格分隔;不传则用 --batch 与 --seqlen 的笛卡尔积")
    parser.add_argument("--nheads", type=int, default=28, help="nheads_q 默认值(未指定 --nheads-q 时)")
    parser.add_argument("--nheads-q", type=int, default=None, help="query 头数,默认 28")
    parser.add_argument("--num-heads-kv", type=int, default=4, help="kv 头数,默认 4(GQA)")
    parser.add_argument("--headdim", type=int, nargs="+", default=[64, 128], help="head 维度,默认 64,128")
    parser.add_argument("--repeats", type=int, default=30)
    parser.add_argument("--causal", action="store_true", default=True, help="causal=True(默认)")
    parser.add_argument("--no-causal", action="store_false", dest="causal", help="causal=False")
    parser.add_argument("--both-causal", action="store_true", default=True, help="同时跑 causal True 与 False(默认开启,无参时出 4 张表)")
    parser.add_argument("--no-both-causal", action="store_false", dest="both_causal", help="只跑当前 --causal 一种")
    parser.add_argument("--backward", action="store_true", default=True, help="是否测 backward(默认开启,无参时出 4 张表)")
    parser.add_argument("--no-backward", action="store_false", dest="backward")
    parser.add_argument("--dtype", choices=["fp16", "bf16"], default="fp16")
    parser.add_argument("--max-mask-gb", type=float, default=24.0, help="attn_mask 显存超过此值(GiB)时跳过该尺寸,避免 OOM;0 表示不限制")
    return parser


def _parse_size_pairs(spec):
    """Parse a ``--sizes`` string such as ``"1,1024 8,2048"`` into ``[(batch, seqlen), ...]``.

    Raises:
        ValueError: if a token is not exactly two comma-separated integers.
    """
    pairs = []
    for token in spec.split():
        fields = token.split(",")
        if len(fields) != 2:
            raise ValueError(f"expected 'batch,seqlen' but got {token!r}")
        pairs.append((int(fields[0]), int(fields[1])))
    return pairs


def _perf_cells(flop, t_mask_ms, t_no_ms):
    """Format the five metric cells of a table row from one pair of timings.

    Returns the tab-joined cells: TFLOPs with mask, time with mask, TFLOPs
    without mask, time without mask, and the with/without TFLOPs percentage.
    """
    tflops_no = efficiency(flop, t_no_ms / 1000.0)
    tflops_mask = efficiency(flop, t_mask_ms / 1000.0)
    pct = (tflops_mask / tflops_no * 100.0) if tflops_no > 0 else 0.0
    return f"{tflops_mask:.2f}\t{t_mask_ms:.2f}\t{tflops_no:.2f}\t{t_no_ms:.2f}\t{pct:.1f}%"


def _measure_size(batch, seqlen, nheads_q, num_heads_kv, headdim, dtype, device, causal, repeats, run_bwd):
    """Allocate inputs for one problem size and time both kernels.

    Returns ``(oom_stage, fwd_ms, fwd_bwd_ms)``:
      * ``oom_stage`` is ``None`` on success, otherwise ``"alloc"``,
        ``"forward"`` or ``"backward"`` naming the stage that ran out of memory;
      * ``fwd_ms`` is ``(t_no, t_mask)`` from the forward-only timing, or
        ``None`` if that stage was not reached;
      * ``fwd_bwd_ms`` is ``(no_fwd, no_bwd, mask_fwd, mask_bwd)`` from the
        forward+backward timing, or ``None`` if not run / not completed.

    All tensors are locals of this function, so they become unreachable when
    it returns; the caller can then free the CUDA cache outside any ``except``
    block before the next size is allocated.
    """
    try:
        q = torch.randn(batch, seqlen, nheads_q, headdim, dtype=dtype, device=device)
        k = torch.randn(batch, seqlen, num_heads_kv, headdim, dtype=dtype, device=device)
        v = torch.randn(batch, seqlen, num_heads_kv, headdim, dtype=dtype, device=device)
        attn_mask = torch.ones(batch, nheads_q, seqlen, seqlen, dtype=torch.bool, device=device)
    except torch.cuda.OutOfMemoryError:
        return "alloc", None, None

    try:
        t_no = _time_forward_ms(flash_attn_func, q, k, v, causal=causal, repeats=repeats)
        t_mask = _time_forward_ms(flash_attn_with_mask_func, q, k, v, attn_mask, causal=causal, repeats=repeats)
    except torch.cuda.OutOfMemoryError:
        return "forward", None, None
    fwd_ms = (t_no, t_mask)
    if not run_bwd:
        return None, fwd_ms, None

    try:
        q.requires_grad_(True)
        k.requires_grad_(True)
        v.requires_grad_(True)
        no_fwd, no_bwd = _time_fwd_bwd_ms(flash_attn_func, q, k, v, causal=causal, repeats=repeats)
        # Fresh leaf tensors for the masked run.
        q2 = q.detach().clone().requires_grad_(True)
        k2 = k.detach().clone().requires_grad_(True)
        v2 = v.detach().clone().requires_grad_(True)
        mask_fwd, mask_bwd = _time_fwd_bwd_ms(
            flash_attn_with_mask_func, q2, k2, v2, attn_mask, causal=causal, repeats=repeats
        )
    except torch.cuda.OutOfMemoryError:
        return "backward", fwd_ms, None
    return None, fwd_ms, (no_fwd, no_bwd, mask_fwd, mask_bwd)


def _print_detail_header(args, nheads_q, num_heads_kv, headdim, causal, run_bwd):
    """Print the banner and column titles used by the ``--no-table`` format."""
    rule = "=" * 90
    print("\n" + rule)
    print("Benchmark: 无 attnmask vs 有 attnmask — 各 size 延时 (ms) 与速度比 (attnmask/no_attnmask)")
    print(rule)
    print(f"  dtype={args.dtype}, nheads_q={nheads_q}, num_heads_kv={num_heads_kv}, headdim={headdim}, causal={causal}, repeats={args.repeats}")
    if args.backward and headdim not in _BWD_HEADDIMS:
        print("  backward 对比仅在 headdim=64/128 时执行,当前 dim 只统计 forward。")
    if run_bwd:
        print(f"  {'batch':>5} {'seqlen':>7}{'no_attnmask_fwd':>12} {'attnmask_fwd':>12} {'ratio_fwd':>9} │ "
              f"{'no_attnmask_bwd':>12} {'attnmask_bwd':>12} {'ratio_bwd':>9}")
    else:
        print(f"  {'batch':>5} {'seqlen':>7}{'no_attnmask(ms)':>14} {'attnmask(ms)':>14}{'speed_ratio':>10}  (attnmask/no_attnmask, >1 表示 attnmask 更慢)")
    print("-" * 90)


def main():
    """Benchmark flash_attn_func against flash_attn_with_mask_func from the CLI.

    For every head dimension and every causal setting, each (batch, seqlen)
    size is timed for forward and, when enabled and the head dimension is 64
    or 128, for backward. Results are printed either as tab-separated tables
    (default; the backward table follows the forward table) or in the aligned
    ``--no-table`` format. Sizes whose attn_mask exceeds ``--max-mask-gb`` are
    skipped, and CUDA out-of-memory errors are reported per row instead of
    aborting the run.

    Invalid arguments (a non-divisible GQA head configuration, a malformed
    ``--sizes`` value) are reported through ``parser.error``, which exits the
    process with a usage message.
    """
    parser = _build_parser()
    args = parser.parse_args()

    nheads_q = args.nheads_q if args.nheads_q is not None else args.nheads
    num_heads_kv = args.num_heads_kv if args.num_heads_kv is not None else nheads_q
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if num_heads_kv <= 0 or nheads_q % num_heads_kv != 0:
        parser.error("nheads_q must be divisible by num_heads_kv (GQA)")

    device = "cuda"
    dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16
    dtype_str = "float16" if args.dtype == "fp16" else "bfloat16"

    try:
        size_pairs = _parse_size_pairs(args.sizes) if args.sizes else []
    except ValueError as exc:
        parser.error(f"--sizes: {exc}")
    if not size_pairs:
        # No explicit pairs: use the cartesian product of --batch and --seqlen.
        size_pairs = [(b, s) for b in args.batch for s in args.seqlen]

    repeats = args.repeats
    causal_vals = [True, False] if args.both_causal else [args.causal]
    fwd_header = _TABLE_KEY_COLUMNS + "\ttflops_attnmask_fwd\ttime_attnmask_fwd(ms)\ttflops_no_fwd\ttime_no_fwd(ms)\tfwd(%)"
    bwd_header = _TABLE_KEY_COLUMNS + "\ttflops_attnmask_bwd\ttime_attnmask_bwd(ms)\ttflops_no_bwd\ttime_no_bwd(ms)\tbwd(%)"
    fwd_only_header = _TABLE_KEY_COLUMNS + "\ttflops_attnmask\ttime_attnmask(ms)\ttflops_no\ttime_no(ms)\ttflops_attnmask/no_attnmask(%)"

    for headdim in args.headdim:
        run_bwd = args.backward and headdim in _BWD_HEADDIMS
        if args.table:
            print(f"\n=== dim={headdim} ===", flush=True)
        for causal in causal_vals:
            # Backward rows are buffered and printed after the forward table.
            rows_bwd = []
            if args.table:
                print(fwd_header if run_bwd else fwd_only_header, flush=True)
            else:
                _print_detail_header(args, nheads_q, num_heads_kv, headdim, causal, run_bwd)

            for batch, seqlen in size_pairs:
                key = f"{batch}\t{seqlen}\t{seqlen}\t{nheads_q}\t{num_heads_kv}\t{causal}\t{headdim}\t{dtype_str}"
                lead = f"  {batch:>5} {seqlen:>7}"
                oom_row = f"{key}\tOOM\t-\tOOM\t-\t-"

                mask_gb = attn_mask_bytes(batch, nheads_q, seqlen) / (1024**3)
                if args.max_mask_gb > 0 and mask_gb > args.max_mask_gb:
                    if args.table:
                        skip_row = f"{key}\t-\t-\tskip(OOM)\t{mask_gb:.1f}GiB_mask\t-"
                        print(skip_row, flush=True)
                        if run_bwd:
                            rows_bwd.append(skip_row)
                    else:
                        print(f"{lead} │ skip (attn_mask 约 {mask_gb:.1f} GiB > --max-mask-gb {args.max_mask_gb})")
                    continue

                oom_stage, fwd_ms, fwd_bwd_ms = _measure_size(
                    batch, seqlen, nheads_q, num_heads_kv, headdim, dtype, device, causal, repeats, run_bwd
                )
                if oom_stage is not None:
                    # Done here, after the helper returned and outside any
                    # `except` block, so the failed size's tensors are no
                    # longer referenced and the cache can actually shrink.
                    torch.cuda.empty_cache()

                if oom_stage in ("alloc", "forward"):
                    if args.table:
                        print(oom_row, flush=True)
                        if run_bwd:
                            rows_bwd.append(oom_row)
                    elif oom_stage == "alloc":
                        print(f"{lead} │ OOM (attn_mask 约 {mask_gb:.1f} GiB)")
                    else:
                        print(f"{lead} │ OOM (forward)")
                    continue

                t_no, t_mask = fwd_ms
                if fwd_bwd_ms is not None:
                    # When the fwd+bwd run succeeded, report its forward
                    # timings so that time, TFLOPs and ratio in one row all
                    # come from the same measurement.
                    no_fwd, no_bwd, mask_fwd, mask_bwd = fwd_bwd_ms
                    t_no, t_mask = no_fwd, mask_fwd

                if args.table:
                    flop_fwd = flops(batch, seqlen, headdim, nheads_q, causal, mode="fwd")
                    print(f"{key}\t{_perf_cells(flop_fwd, t_mask, t_no)}", flush=True)
                    if oom_stage == "backward":
                        rows_bwd.append(oom_row)
                    elif run_bwd:
                        flop_bwd = flops(batch, seqlen, headdim, nheads_q, causal, mode="bwd")
                        rows_bwd.append(f"{key}\t{_perf_cells(flop_bwd, mask_bwd, no_bwd)}")
                    continue

                ratio_fwd = t_mask / t_no if t_no > 0 else 0.0
                if not run_bwd:
                    print(f"{lead}{t_no:>14.3f} {t_mask:>14.3f}{ratio_fwd:>9.2f}x")
                elif oom_stage == "backward":
                    print(f"{lead}{t_no:>12.3f} {t_mask:>12.3f} {ratio_fwd:>8.2f}x │ OOM (backward)")
                else:
                    ratio_bwd = mask_bwd / no_bwd if no_bwd > 0 else 0.0
                    print(f"{lead}{t_no:>12.3f} {t_mask:>12.3f} {ratio_fwd:>8.2f}x │ "
                          f"{no_bwd:>12.3f} {mask_bwd:>12.3f} {ratio_bwd:>8.2f}x")

            if args.table and run_bwd and rows_bwd:
                print(bwd_header, flush=True)
                for r in rows_bwd:
                    print(r, flush=True)

            if not args.table:
                print("=" * 90)
                print("speed_ratio = attnmask_time / no_attnmask_time  (>1 表示 attnmask 更慢)")
                print()

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()