"examples/vision/python_super_resolution/cat.jpg" did not exist on "78eaf2b80d39277a59ff600573949740439259d3"
benchmark_fbgemm_grouped_gemm.py 16.6 KB
Newer Older
1
# python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
2
3
4
5
import argparse

import torch
import triton
6
7
8
9
10
11
12
13
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,
    triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm as fbgemm_grouped_gemm,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
    grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
)
from transformers import AutoConfig

from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton as sglang_grouped_gemm,
)


def get_model_config(model_name: str, tp_size: int):
    """Load a HuggingFace config and extract the MoE GEMM shapes to benchmark.

    Args:
        model_name: HuggingFace model id or local path, passed to
            ``AutoConfig.from_pretrained``.
        tp_size: Tensor-parallel world size. Each expert's intermediate
            dimension is split across ranks, so the returned
            ``intermediate_size`` is the per-rank value
            ``intermediate_size // tp_size``.

    Returns:
        dict with keys ``num_groups`` (number of experts), ``hidden_size``,
        ``intermediate_size`` (per TP rank) and ``dtype`` (the config's
        ``torch_dtype``).

    Raises:
        ValueError: if ``tp_size`` is not positive or does not evenly divide
            the model's intermediate size.
    """
    if tp_size < 1:
        raise ValueError(f"tp_size must be a positive integer, got {tp_size}")

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    arch = config.architectures[0]
    # Config object that carries hidden_size; only differs for Llama4 below.
    text_config = config

    if arch == "DbrxForCausalLM":
        num_groups = config.ffn_config.moe_num_experts
        intermediate_size = config.ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.intermediate_size
    elif arch in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif arch in ("DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"):
        num_groups = config.n_routed_experts
        intermediate_size = config.moe_intermediate_size
    elif arch == "Llama4ForConditionalGeneration":
        # The language-model fields of this multimodal wrapper live on the
        # nested text config (expert count and intermediate size are read
        # from there too), so hidden_size must come from the same place.
        text_config = config.text_config
        num_groups = text_config.num_local_experts
        intermediate_size = text_config.intermediate_size
    elif arch in ("Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"):
        num_groups = config.num_local_experts
        intermediate_size = config.moe_intermediate_size
    else:
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size

    if intermediate_size % tp_size != 0:
        raise ValueError(
            f"intermediate_size={intermediate_size} is not divisible by "
            f"tp_size={tp_size}"
        )

    shape_configs = {
        "num_groups": num_groups,
        "hidden_size": text_config.hidden_size,
        "intermediate_size": intermediate_size // tp_size,
        "dtype": config.torch_dtype,
    }
    print(f"{shape_configs=}")
    return shape_configs


def create_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, device="cuda"
):
    """Build BF16 inputs for the FBGEMM and SGLang grouped GEMM providers.

    Tokens are split across groups as evenly as possible. Every group gets
    ``batch_size // num_groups`` rows, and the first
    ``batch_size % num_groups`` groups get one extra row, so ``m_sizes``
    always sums to ``batch_size`` and every row of ``x`` belongs to a group.

    Args:
        batch_size: Total number of token rows in ``x``.
        num_groups: Number of groups (experts).
        hidden_size: Contracted dimension shared by ``x`` and the weights.
        intermediate_size: Output columns produced per group.
        device: Device on which every tensor is allocated.

    Returns:
        Tuple ``(x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes,
        seg_indptr, weight_indices)``. ``w_sglang`` is
        ``(num_groups, intermediate_size, hidden_size)``. ``w_fbgemm`` is
        the same weights reshaped to
        ``(num_groups * intermediate_size, hidden_size)``. ``m_sizes``
        holds the per-group row counts. ``seg_indptr`` holds the segment
        boundaries (group ``i`` owns rows
        ``seg_indptr[i]:seg_indptr[i + 1]``). Both are int32.
    """
    torch.manual_seed(42)

    base, extra = divmod(batch_size, num_groups)
    m_sizes = torch.tensor(
        [base + (i < extra) for i in range(num_groups)],
        dtype=torch.int32,
        device=device,
    )

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)

    base_weights = torch.randn(
        num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    # FBGEMM takes the groups stacked along the row dimension (2D);
    # SGLang takes the 3D per-group layout. Both share the same storage.
    w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
    w_sglang = base_weights

    c_fbgemm = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )

    # Segment boundaries as a single vectorized prefix sum instead of one
    # device write per group.
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device=device)
    seg_indptr[1:] = torch.cumsum(m_sizes, dim=0)

    weight_indices = torch.arange(num_groups, dtype=torch.int32, device=device)

    return (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    )


def create_fp8_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
    """
    Create test data for FP8 grouped GEMM operations.

    Tokens are split across groups as evenly as possible. The first
    ``batch_size % num_groups`` groups receive one extra token, so the group
    sizes always sum to ``batch_size``.

    Args:
        batch_size: Total batch size
        num_groups: Number of groups
        hidden_size: Hidden dimension size
        intermediate_size: Intermediate dimension size
        backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM

    Returns:
        For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale). Weights and
            scales are concatenated along dim 0; m_sizes is int32.
        For cutlass: (x, wq, w_scale, m_sizes). Weights and scales are
            stacked per group; x is unquantized float16; m_sizes is int64.

    Raises:
        ValueError: if ``backend`` is neither "triton" nor "cutlass".
    """
    # Validate first so an unsupported backend fails before any GPU
    # allocation or quantization work.
    if backend not in ("triton", "cutlass"):
        raise ValueError(f"Unsupported backend: {backend}")

    torch.manual_seed(42)

    base, extra = divmod(batch_size, num_groups)
    m_values = [base + (i < extra) for i in range(num_groups)]

    # One float16 weight matrix per group, each quantized row-wise to FP8.
    w_list = [
        torch.randn(intermediate_size, hidden_size, dtype=torch.float16, device="cuda")
        for _ in range(num_groups)
    ]
    wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])

    if backend == "triton":
        # Triton format: concatenated weights
        w_fp8 = torch.concat(wq_list, dim=0).contiguous()
        w_scale = torch.concat(w_scale_list, dim=0).contiguous()

        m_sizes = torch.tensor(m_values, dtype=torch.int32, device="cuda")

        # Create and quantize input
        x_fp16 = torch.randn(
            batch_size, hidden_size, dtype=torch.float16, device="cuda"
        )
        x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
        x_scale = x_scale.view(batch_size, -1)

        return x_fp8, w_fp8, m_sizes, x_scale, w_scale

    # CUTLASS format: stacked weights
    wq = torch.stack(wq_list, dim=0).contiguous()
    w_scale = torch.stack(w_scale_list, dim=0).contiguous()

    m_sizes = torch.tensor(m_values, dtype=torch.int64, device="cuda")

    # Per-group inputs concatenated into one (batch_size, hidden_size) tensor,
    # so callers can reshape per-row scales with view(batch_size, -1).
    x = torch.concat(
        [
            torch.randn(m, hidden_size, dtype=torch.float16, device="cuda")
            for m in m_values
        ],
        dim=0,
    ).contiguous()

    return x, wq, w_scale, m_sizes


def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
    """
    Calculate memory bandwidth based on accessed expert weights.

    Args:
        m_sizes: Tensor containing batch sizes for each group
        hidden_size: Hidden dimension size
        intermediate_size: Intermediate dimension size
        dtype: Data type of weights

    Returns:
        Memory size in bytes for accessed expert weights
    """
    # Count non-zero groups (active experts)
    if hasattr(m_sizes, "cpu"):
        active_experts = torch.count_nonzero(m_sizes).item()
    else:
        active_experts = sum(1 for m in m_sizes if m > 0)

    # Calculate bytes per element based on dtype
    if dtype in [torch.float16, torch.bfloat16]:
        bytes_per_element = 2
    elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        bytes_per_element = 1
    elif dtype == torch.float32:
        bytes_per_element = 4
    else:
        # Default to 2 bytes for unknown dtypes
        bytes_per_element = 2

    # Memory per expert weight matrix
    memory_per_expert = hidden_size * intermediate_size * bytes_per_element

    # Total memory for active experts
    total_memory_bytes = active_experts * memory_per_expert

    return total_memory_bytes
224
225
226
227
228


def get_benchmark_config(use_fp8_w8a8=False):
    """Return provider ids, legend labels and line styles for the sweep.

    The three lists are index-aligned: ``line_vals[i]`` is the provider id
    passed to the benchmark function, and ``line_names[i]`` / ``styles[i]``
    are its legend label and (color, linestyle).

    In FP8 mode the SGLang baseline still runs with BF16 inputs (the
    ``sglang_grouped_gemm`` provider goes through the BF16 data path), so
    its label says BF16 to match what is actually measured.

    Args:
        use_fp8_w8a8: Select the FP8 provider set instead of the BF16 one.

    Returns:
        dict with keys ``line_vals``, ``line_names`` and ``styles``.
    """
    if use_fp8_w8a8:
        return {
            "line_vals": [
                "fbgemm_triton_grouped_gemm_fp8",
                "fbgemm_cutlass_f8f8bf16_rowwise",
                "sglang_grouped_gemm",
            ],
            "line_names": [
                "FBGEMM Triton Grouped GEMM FP8",
                "FBGEMM CUTLASS F8F8BF16 Rowwise",
                "SGLang Grouped GEMM BF16",
            ],
            "styles": [("blue", "-"), ("orange", "-"), ("red", "-")],
        }

    return {
        "line_vals": ["fbgemm_triton_grouped_gemm", "sglang_grouped_gemm"],
        "line_names": [
            "FBGEMM Triton Grouped GEMM BF16",
            "SGLang Grouped GEMM BF16",
        ],
        "styles": [("blue", "-"), ("green", "-")],
    }


def run_benchmark(
    model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
    """Sweep batch sizes and report weight-read bandwidth (GB/s) per provider.

    Args:
        model_config: Shape dict with ``num_groups``, ``hidden_size`` and
            ``intermediate_size``.
        use_fp8_w8a8: Select the FP8 provider set from
            ``get_benchmark_config``.
        save_path: Output directory passed to the Triton perf-report ``run``.
    """
    config = get_benchmark_config(use_fp8_w8a8)

    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Bandwidth (GB/s)",
        plot_name="grouped-gemm-performance",
        args={},
    )

    # Reported when a provider cannot be built, warmed up or timed. The
    # metric is bandwidth (higher is better), so a failure reports 0 GB/s;
    # inf would make an unsupported kernel look infinitely fast.
    failed = (0.0, 0.0, 0.0)

    @triton.testing.perf_report(benchmark_config)
    def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
        """Build inputs for ``provider``, time it, and return GB/s stats.

        Returns ``(median, min, max)`` bandwidth in GB/s.
        """
        print(f"Benchmarking {provider} with batch_size={batch_size}")
        torch.cuda.manual_seed_all(0)

        num_groups = model_config["num_groups"]
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_triton_grouped_gemm_fp8":
            try:
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="triton",
                )
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
                    )

            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return failed

        elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
            try:
                x, wq, w_scale, m_sizes = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="cutlass",
                )
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                # Activation quantization happens once here, outside the
                # timed function.
                xq, x_scale = triton_quantize_fp8_row(x)
                x_scale = x_scale.view(batch_size, -1)

                def run_func():
                    return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
                        xq, wq, x_scale, w_scale, m_sizes
                    )

            except Exception as e:
                print(
                    f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
                    f"skipping: {e}"
                )
                return failed

        else:
            # BF16 data path: used by both BF16 providers and by the SGLang
            # baseline in FP8 mode.
            (
                x,
                w_fbgemm,
                w_sglang,
                _c_fbgemm,
                c_sglang,
                m_sizes,
                seg_indptr,
                weight_indices,
            ) = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)

            memory_bytes = calculate_memory_bandwidth(
                m_sizes, hidden_size, intermediate_size, torch.bfloat16
            )

            if provider == "fbgemm_triton_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
                        x, w_fbgemm, m_sizes, use_fast_accum=True
                    )

            else:

                def run_func():
                    return sglang_grouped_gemm(
                        x,
                        w_sglang,
                        c_sglang,
                        num_groups,
                        weight_column_major=True,
                        seg_indptr=seg_indptr,
                        weight_indices=weight_indices,
                        c_dtype=c_sglang.dtype,
                    )

        # Warm up and surface launch errors before timing.
        for _ in range(10):
            try:
                run_func()
            except Exception as e:
                print(f"Error during warmup for {provider}: {e}")
                return failed

        torch.cuda.synchronize()

        try:
            # Quantiles are returned in request order: median, p20, p80.
            ms, min_ms, max_ms = triton.testing.do_bench(
                run_func, quantiles=[0.5, 0.2, 0.8]
            )
            gigabytes = memory_bytes / 1e9
            # Bandwidth is inversely proportional to time: the slowest sample
            # gives the minimum bandwidth and the fastest the maximum.
            gb_per_s = gigabytes / (ms / 1000)
            min_gb_per_s = gigabytes / (max_ms / 1000)
            max_gb_per_s = gigabytes / (min_ms / 1000)
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return failed

        return gb_per_s, min_gb_per_s, max_gb_per_s

    dynamic_benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=save_path,
        model_config=model_config,
        use_fp8_w8a8=use_fp8_w8a8,
    )


def verify_correctness(model_config, batch_size=128, rtol=1.6e-2, atol=1e-2):
    """Check that the FBGEMM and SGLang BF16 grouped GEMMs agree on one batch.

    The default tolerances are sized for bfloat16 outputs. Two kernels that
    accumulate in a different order can legitimately differ by about one
    bf16 ulp (~0.4-0.8% relative), which a 1e-3 tolerance would reject.

    Args:
        model_config: Shape dict with ``num_groups``, ``hidden_size`` and
            ``intermediate_size``.
        batch_size: Number of token rows to test with.
        rtol: Relative tolerance for ``torch.allclose``.
        atol: Absolute tolerance for ``torch.allclose``.

    Returns:
        True if the outputs match within tolerance, False otherwise.
    """
    print("Verifying correctness...")
    num_groups = model_config["num_groups"]
    hidden_size = model_config["hidden_size"]
    intermediate_size = model_config["intermediate_size"]

    (
        x,
        w_fbgemm,
        w_sglang,
        _c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    ) = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)

    result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    # Compare only rows that some group actually owns. The SGLang output
    # buffer comes from create_test_data's torch.empty allocation, so any
    # row not covered by m_sizes holds arbitrary data.
    valid_rows = int(m_sizes.sum().item())
    result_fbgemm = result_fbgemm[:valid_rows].float()
    result_sglang = result_sglang[:valid_rows].float()

    if torch.allclose(result_fbgemm, result_sglang, rtol=rtol, atol=atol):
        print("✓ BF16 Correctness verification passed!")
        return True

    max_diff = torch.max(torch.abs(result_fbgemm - result_sglang)).item()
    print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
    return False


def main():
    """Command-line entry point: resolve shapes, optionally verify, then benchmark.

    Exits with status 1 when correctness verification or the benchmark run
    fails, so wrapping scripts and CI can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark FBGEMM vs SGLang Grouped GEMM"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name to get configuration from",
    )
    parser.add_argument(
        "--tp-size", type=int, default=1, help="Tensor parallelism size"
    )
    parser.add_argument(
        "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="./benchmark_grouped_gemm/",
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--verify-correctness",
        action="store_true",
        help="Verify correctness before benchmarking",
    )

    args = parser.parse_args()

    try:
        model_config = get_model_config(args.model, args.tp_size)
    except Exception as e:
        # Best effort: keep benchmarking with a fixed default shape when the
        # model config cannot be loaded (e.g. offline or unknown model).
        print(f"Failed to get model config: {e}")
        print("Using default configuration...")
        model_config = {
            "num_groups": 8,
            "hidden_size": 4096,
            "intermediate_size": 14336,
            "dtype": torch.bfloat16,
        }

    print("Running benchmark with:")
    print(f"  num_groups: {model_config['num_groups']}")
    print(f"  hidden_size: {model_config['hidden_size']}")
    print(f"  intermediate_size: {model_config['intermediate_size']}")
    print(f"  use_fp8_w8a8: {args.use_fp8_w8a8}")

    if args.verify_correctness and not verify_correctness(model_config):
        print("Correctness verification failed. Exiting...")
        raise SystemExit(1)

    try:
        run_benchmark(
            model_config=model_config,
            use_fp8_w8a8=args.use_fp8_w8a8,
            save_path=args.save_path,
        )
    except Exception as e:
        print(f"Benchmark failed: {e}")
        raise SystemExit(1) from e

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()