# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
from contextlib import nullcontext

import pandas as pd
import torch
import torch.utils.benchmark as benchmark

from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import (
    Float8BlockScaling,
    MXFP8BlockScaling,
    NVFP4BlockScaling,
)
from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager

"""
# Profile BF16 recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/b200_numgemm_8_bf16 \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16

# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel

# Profile MXFP8 recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/b200_numgemm_8_mxfp8 \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8

# Profile NVFP4 recipe with Nsight Systems
nsys profile \
    --output=./benchmarks/linear/b200_numgemm_8_nvfp4 \
    --force-overwrite true \
    --trace=cuda,nvtx,cudnn,cublas \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4

# Example for jagged input benchmark to simulate unbalanced token splits
python benchmarks/linear/benchmark_grouped_linear.py --recipe nvfp4 --jagged-input "15296,8960,14656,14784,11712,7936,14080,10880"

# Example to look at a single kernel target with NCU, like the fused hadamard amax kernel for NVFP4 recipe
ncu -f -o ./benchmarks/linear/ncu_b200_numgemm_8_nvfp4_rht_amax \
    --set=full \
    --kernel-name "GroupHadamardAmaxTmaKernel" \
    -s 5 -c 5 \
    python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4

"""

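# Map recipe name -> quantization recipe. "bf16" maps to None, i.e. no
# quantization autocast is entered and the layer runs as a plain BF16 baseline.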
RECIPES = {
    "bf16": None,
    "fp8_sub_channel": Float8BlockScaling(),
    "mxfp8": MXFP8BlockScaling(),
    "nvfp4": NVFP4BlockScaling(),
}

mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
    FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()


def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
    assert mode in ["fwd_only", "fwd_bwd"]
    quantization_context = (
        autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext()
    )

    if mode == "fwd_only":
        with torch.no_grad(), quantization_context:
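            # is_first_microbatch=True on the first step lets the layer quantize
            # its weights once; later steps reuse the cached quantized weights.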
            for i in range(run_num_steps):
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
        return y_q
    else:
        # reset gradients
        layer.zero_grad()
        x.grad = None

        with quantization_context:
            for i in range(run_num_steps):
                label = f"step_{i}"
                torch.cuda.nvtx.range_push(label)
                y_q = layer.forward(
                    x,
                    m_splits,
                    is_first_microbatch=(i == 0),
                )
                y_q.backward(gradient)
                torch.cuda.nvtx.range_pop()

        grads_q = [x.grad]
        # remaining derivatives are with respect to model parameters
        for p in layer.parameters():
            if p.requires_grad:
                grads_q.append(p.grad)

        return y_q, grads_q


def benchmark_linear(
    x,
    ws,
    m_splits,
    bias,
    recipe_name,
    mode,
    num_gemms=4,
):
    params_dtype = torch.bfloat16
    recipe = RECIPES[recipe_name]

    in_features = x.shape[1]
    out_features = ws[0].shape[0]
    gradient = torch.ones((x.shape[0], out_features), dtype=torch.bfloat16, device=x.device)

    layer = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=bias is not None,
        params_dtype=params_dtype,
    )

    layer = layer.to("cuda")
    with torch.no_grad():
        for i in range(num_gemms):
            weight_i = getattr(layer, f"weight{i}")
            weight_i.copy_(ws[i])
            if bias is not None:
                bias_i = getattr(layer, f"bias{i}")
                bias_i.copy_(bias)

    num_microbatches = 32

    label = f"{recipe_name}_grouped"
    torch.cuda.nvtx.range_push(label)
    timing = benchmark.Timer(
        stmt=(
            "run_linear_multiple_steps(layer, x, m_splits, mode, gradient, num_microbatches,"
            " recipe)"
        ),
        globals={
            "run_linear_multiple_steps": run_linear_multiple_steps,
            "layer": layer,
            "x": x,
            "m_splits": m_splits,
            "mode": mode,
            "gradient": gradient,
            "num_microbatches": num_microbatches,
            "recipe": recipe,
        },
        num_threads=1,
    ).blocked_autorange(min_run_time=10)
    # close the NVTX range opened before creating the timer
    torch.cuda.nvtx.range_pop()
    print(f"{recipe_name}: {timing} \n")
    timing_ms = timing.median * 1000 / num_microbatches

    return timing_ms


def run_benchmark_linear(
    mkns, recipe_name, use_bias, num_gemms=4, m_splits_provided=None, fwd_only=False
):
    data = []
    assert not use_bias, "Bias is not supported for GroupedLinear benchmark"

    print(f"========== Benchmarking {recipe_name} ==========")
    for m, k, n in mkns:
        device = "cuda"
        x = torch.randn((m, k), dtype=torch.bfloat16, device=device, requires_grad=True)
        ws = [torch.randn((n, k), dtype=torch.bfloat16, device=device) for _ in range(num_gemms)]
        m_splits = [m // num_gemms] * num_gemms if m_splits_provided is None else m_splits_provided
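        # the per-GEMM splits must cover all m rows of x
        assert sum(m_splits) == m, f"m_splits {m_splits} must sum to m={m}"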
        # Bias is not supported for GroupedLinear benchmark
        bias = None

        # Run the benchmark
        print(f"fwd_m={m}, fwd_k={k}, fwd_n={n}")
        print(f"m_splits: {m_splits}")
        print(f"fwd_only: {fwd_only}")

        grouped_timing_ms = benchmark_linear(
            x,
            ws,
            m_splits,
            bias,
            recipe_name,
            mode="fwd_only" if fwd_only else "fwd_bwd",
            num_gemms=num_gemms,
        )

        # Append the results
        data.append(
            [
                m,
                k,
                n,
                recipe_name,
                num_gemms,
                grouped_timing_ms,
            ]
        )

    timing_notation = "grouped_fwd_time_ms" if fwd_only else "grouped_fwd_bwd_time_ms"

    df = pd.DataFrame(
        data=data,
        columns=[
            "m",
            "k",
            "n",
            "recipe",
            "num_gemms",
            timing_notation,
        ],
    )

    print(df, "\n")
    return df


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="benchmark_output/",
        help="output path for report",
    )
    # arguments for recipe, options are bf16, fp8_sub_channel, mxfp8, nvfp4, or all
    parser.add_argument(
        "--recipe",
        type=str,
        default="bf16",
        help="Recipe to use, options are fp8_sub_channel, mxfp8, bf16, or all",
    )
    # add an argument for the jagged input
    # example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880] => sums up to 98304
    parser.add_argument(
        "--jagged-input",
        type=str,
        default=None,
        help="Jagged input to use, example: [15296, 8960, 14656, 14784, 11712, 7936, 14080, 10880]",
    )
    parser.add_argument(
        "--hidden-dim",
        type=int,
        default=7168,
        help="Hidden dimension to use, default is 7168",
    )
    parser.add_argument(
        "--output-dim",
        type=int,
        default=2048,
        help="Output dimension to use, default is 2048",
    )
    parser.add_argument(
        "--fwd-only",
        action="store_true",
        default=False,
        help="Run forward pass only, default is both forward and backward passes",
    )
    args = parser.parse_args()

    jagged_input_splits = None
    if args.jagged_input is not None:
        jagged_input_splits = [int(x) for x in args.jagged_input.split(",")]
        print(f"Jagged input splits: {jagged_input_splits}")
        print(f"Jagged input splits sum: {sum(jagged_input_splits)}")
        print(f"Jagged input splits num_gemms: {len(jagged_input_splits)}")

    use_bias = False
    # Set the MKN values to benchmark
    # Deepseek V3 EP64, SEQ_LEN=8192, topK=8
    # 256 experts / EP64 => 4 local experts per rank
    # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384
    # M = AvgM * localExperts = 65536
    # K = 7168
    # N = 2048

    # Deepseek V3 EP32, SEQ_LEN=8192, topK=8
    # 256 experts / EP32 => 8 local experts per rank
    # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192
    # M = AvgM * localExperts = 65536
    # K = 7168
    # N = 2048

    # 4 or 8 local experts per rank
    num_gemms_list = [4, 8]

    if jagged_input_splits is not None:
        num_gemms_list = [len(jagged_input_splits)]

    token_dim_list = [16384, 32768, 65536, 98304]
    # --hidden-dim and --output-dim have defaults (7168 and 2048), so they always apply
    hidden_dim_list = [args.hidden_dim]
    output_dim_list = [args.output_dim]

    # override the default token dims if a jagged input is specified
    if jagged_input_splits is not None:
        token_dim_list = [sum(jagged_input_splits)]

    # MKN shapes for GroupedLinear
    mkns = []
    for m in token_dim_list:
        for k in hidden_dim_list:
            for n in output_dim_list:
                mkns.append((m, k, n))

    # recipes to run
    if args.recipe == "all":
        recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"]
    else:
        recipe_list = [args.recipe]

    if args.profile:
        num_gemms_list = [8]
        hidden_dim_to_profile = args.hidden_dim
        output_dim_to_profile = args.output_dim
        token_dim_to_profile = 8192 * 8
        if jagged_input_splits is not None:
            num_gemms_list = [len(jagged_input_splits)]
            token_dim_to_profile = sum(jagged_input_splits)
        mkns = [(token_dim_to_profile, hidden_dim_to_profile, output_dim_to_profile)]
        # in profile mode, only run one recipe specified in args.recipe
        assert args.recipe != "all", (
            "In profile mode, only one recipe can be specified, please specify the recipe as"
351
            " fp8_sub_channel, mxfp8, nvfp4, or bf16"
        )
        recipe_list = [args.recipe]
        # keep a handle on the profiler context so it can be exited cleanly at the end
        profiler_ctx = torch.autograd.profiler.emit_nvtx(record_shapes=True)
        profiler_ctx.__enter__()

    # Initialize a dataframe to store the results
    df_linears = pd.DataFrame()

    # Run the fp8 benchmarks
    for num_gemms in num_gemms_list:
        print(f"========== Benchmarking with num_gemms={num_gemms} ==========")
        for recipe_name in recipe_list:
            assert recipe_name in [
                "bf16",
                "fp8_sub_channel",
                "mxfp8",
                "nvfp4",
            ], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4"
            if recipe_name == "mxfp8" and not mxfp8_available:
                print(f"MXFP8 is not available, skipping {recipe_name}")
                continue
            if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available:
                print(f"FP8 block scaling is not available, skipping {recipe_name}")
                continue
            if recipe_name == "nvfp4" and not nvfp4_available:
                print(f"NVFP4 is not available, skipping {recipe_name}")
                continue

            df = run_benchmark_linear(
                mkns,
                recipe_name,
                use_bias,
                num_gemms=num_gemms,
                m_splits_provided=jagged_input_splits,
                fwd_only=args.fwd_only,
            )
            df_linears = pd.concat([df_linears, df])

    print(df_linears)
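
    # NOTE (assumed intent): --output-dir is parsed above but never used elsewhere in
    # this script; as a minimal sketch, persist the collected results there as a CSV.
    import os

    os.makedirs(args.output_dir, exist_ok=True)
    df_linears.to_csv(os.path.join(args.output_dir, "grouped_linear_benchmark.csv"), index=False)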

    if args.profile:
        profiler_ctx.__exit__(None, None, None)