benchmark_fp8_block_dense_gemm.py 15.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# fmt: off
# ruff: noqa: E501
import time

# Import DeepGEMM functions
import deep_gemm
import torch
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
def per_token_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Convert tensor to FP8 format with per-token scaling."""
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


# Copied from
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
def per_block_cast_to_fp8(
        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one scale per 128x128 tile.

    ``x`` is copied into a zero-filled buffer whose sides are rounded up to
    multiples of 128, each 128x128 tile is rescaled so that its largest
    magnitude maps to 448.0, and the padding is sliced away again.

    Args:
        x: 2-D tensor of any size.

    Returns:
        A ``(quantized, scales)`` pair: ``quantized`` is a contiguous
        ``torch.float8_e4m3fn`` tensor with the shape of ``x``; ``scales``
        has one entry per tile (``tile_amax / 448.0``), shaped
        ``(padded_rows // 128, padded_cols // 128)``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded_rows = ceil_div(rows, 128) * 128
    padded_cols = ceil_div(cols, 128) * 128
    padded = torch.zeros((padded_rows, padded_cols),
                         dtype=x.dtype,
                         device=x.device)
    padded[:rows, :cols] = x
    # Axes after the view: (row tile, row in tile, col tile, col in tile).
    tiles = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Clamp so that an all-zero tile cannot produce a division by zero.
    tile_amax = tiles.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (tiles * (448.0 / tile_amax)).to(torch.float8_e4m3fn)
    values = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = (tile_amax / 448.0).view(tiles.size(0), tiles.size(2))
    return values, scales


def _time_gemm_ms(func, warmup: int, repeat: int) -> float:
    """Return the average duration of one ``func()`` call in milliseconds.

    ``func`` is called ``warmup`` times (synchronizing the device after each
    call) and then ``repeat`` times back-to-back inside the timed region,
    with a single device synchronization before and after that region.
    """
    for _ in range(warmup):
        func()
        torch.cuda.synchronize()

    # Drain outstanding GPU work so it is not billed to the timed region.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted.
    start = time.perf_counter()
    for _ in range(repeat):
        func()
    torch.cuda.synchronize()
    elapsed_s = time.perf_counter() - start

    return elapsed_s / repeat * 1000


def benchmark_shape(m: int,
                    n: int,
                    k: int,
                    warmup: int = 100,
                    repeat: int = 10000,
                    verbose: bool = False) -> dict:
    """Benchmark all implementations for a specific (m, n, k) shape.

    Random BF16 operands ``A`` (m x k) and ``B`` (n x k) are created on the
    CUDA device and quantized to FP8 once, outside the timed region. The
    DeepGEMM, vLLM Triton and vLLM CUTLASS kernels are then checked against
    the BF16 product ``A @ B.t()`` and timed.

    Args:
        m: Rows of ``A`` and of the output.
        n: Rows of ``B``, i.e. columns of the output.
        k: Shared inner dimension. ``per_token_cast_to_fp8`` asserts that it
            is a multiple of 128.
        warmup: Untimed calls made per implementation before timing.
        repeat: Timed calls made per implementation; must be positive.
        verbose: Print per-shape accuracy and timing lines when True.

    Returns:
        ``{"shape": {"m", "n", "k"}, "implementations": {name: stats}}``
        where each ``stats`` dict holds ``time_ms``, ``time_us``, ``tflops``,
        ``gb_s`` and ``diff`` (with keys ``"DeepGEMM"`` and ``"Reference"``).
        Both vLLM entries also carry ``speedup_vs_deepgemm`` and the CUTLASS
        entry additionally carries ``speedup_vs_triton``.

    Raises:
        ValueError: If ``repeat`` is not positive.
    """
    if repeat <= 0:
        raise ValueError(f"repeat must be a positive integer, got {repeat}")

    if verbose:
        print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

    # Create test tensors
    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)

    # Reference result in BF16
    torch.cuda.synchronize()
    C_ref = A @ B.t()

    # Pre-quantize B for all implementations
    # (weights can be pre-quantized offline)
    B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
    B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)

    # Block size configuration
    block_size = [128, 128]

    # Pre-quantize A for all implementations
    A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
    A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
    A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
        A, block_size[1], column_major_scales=True)

    # === DeepGEMM Implementation ===
    def deepgemm_gemm():
        # Writes into the preallocated C_deepgemm buffer and returns it.
        deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
                                       (B_deepgemm, B_scale_deepgemm),
                                       C_deepgemm)
        return C_deepgemm

    # === vLLM Triton Implementation ===
    def vllm_triton_gemm():
        return w8a8_block_fp8_matmul(A_vllm,
                                     B_vllm,
                                     A_scale_vllm,
                                     B_scale_vllm,
                                     block_size,
                                     output_dtype=torch.bfloat16)

    # === vLLM CUTLASS Implementation ===
    def vllm_cutlass_gemm():
        return ops.cutlass_scaled_mm(A_vllm_cutlass,
                                     B_vllm.T,
                                     scale_a=A_scale_vllm_cutlass,
                                     scale_b=B_scale_vllm.T,
                                     out_dtype=torch.bfloat16)

    # Insertion order is also the execution order and the reporting order.
    implementations = {
        "DeepGEMM": deepgemm_gemm,
        "vLLM Triton": vllm_triton_gemm,
        "vLLM CUTLASS": vllm_cutlass_gemm
    }

    # Run correctness check first
    if verbose:
        print("Running correctness check...")
    outputs = {name: func() for name, func in implementations.items()}
    diff_vs_reference = {
        name: calc_diff(out, C_ref)
        for name, out in outputs.items()
    }
    # Computed once here and reused for both the verbose report and the
    # returned results, instead of launching every kernel again after timing.
    diff_vs_deepgemm = {
        name: 0.0 if name == "DeepGEMM" else calc_diff(
            out, outputs["DeepGEMM"])
        for name, out in outputs.items()
    }

    if verbose:
        print("DeepGEMM vs Reference difference: "
              f"{diff_vs_reference['DeepGEMM']:.6f}")
        print("vLLM Triton vs Reference difference: "
              f"{diff_vs_reference['vLLM Triton']:.6f}")
        print("vLLM CUTLASS vs Reference difference: "
              f"{diff_vs_reference['vLLM CUTLASS']:.6f}")
        print("vLLM Triton vs DeepGEMM difference: "
              f"{diff_vs_deepgemm['vLLM Triton']:.6f}")
        print("vLLM CUTLASS vs DeepGEMM difference: "
              f"{diff_vs_deepgemm['vLLM CUTLASS']:.6f}")

    benchmark_results = {
        "shape": {
            "m": m,
            "n": n,
            "k": k
        },
        "implementations": {}
    }

    # Benchmark implementations
    for name, func in implementations.items():
        avg_time_ms = _time_gemm_ms(func, warmup, repeat)
        avg_time_us = avg_time_ms * 1000
        avg_time_s = avg_time_ms * 1e-3

        # 2*m*n*k floating point operations per GEMM.
        tflops = 2 * m * n * k / avg_time_s / 1e12
        # Traffic estimate: presumably one byte per FP8 input element and
        # two bytes per BF16 output element.
        gb_s = (m * k + k * n + m * n * 2) / 1e9 / avg_time_s

        benchmark_results["implementations"][name] = {
            "time_ms": avg_time_ms,
            "time_us": avg_time_us,
            "tflops": tflops,
            "gb_s": gb_s,
            "diff": {
                "DeepGEMM": diff_vs_deepgemm[name],
                "Reference": diff_vs_reference[name]
            }
        }

        if verbose:
            print(
                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
            )

    # Calculate speedups (a value above 1 means faster than DeepGEMM)
    baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
    for name, data in benchmark_results["implementations"].items():
        if name != "DeepGEMM":
            speedup = baseline / data["time_ms"]
            data["speedup_vs_deepgemm"] = speedup
            if verbose:
                print(f"DeepGEMM is {1/speedup:.2f}x "
                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")

    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
        "time_ms"]
    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
        "time_ms"]
    cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
    benchmark_results["implementations"]["vLLM CUTLASS"][
        "speedup_vs_triton"] = cutlass_vs_triton
    if verbose:
        print(
            f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
            f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
        )

    return benchmark_results


def format_table_row(values, widths):
    """Render one table row as ``| cell | cell | ... |``.

    Each value is formatted with the matching entry of ``widths`` as its
    minimum field width, using the value's default alignment. Pairs are
    taken with ``zip``, so surplus values or widths are ignored.
    """
    cells = [
        format(value, format(width, ""))
        for value, width in zip(values, widths)
    ]
    return "| " + " | ".join(cells) + " |"


def print_table(headers, rows, title=None):
    """Print an ASCII table to stdout.

    Args:
        headers: Column titles; their count defines the number of columns.
        rows: Sequence of rows, each indexable by column position. May be
            empty, in which case only the header is printed.
        title: Optional caption printed on its own line above the table,
            preceded by a blank line.
    """
    if title:
        print(f"\n{title}")

    # Each column is as wide as its longest cell or its header. default=0
    # keeps max() from raising ValueError when there are no rows.
    widths = [
        max(len(str(header)),
            max((len(str(row[col])) for row in rows), default=0))
        for col, header in enumerate(headers)
    ]

    # Create separator line
    separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"

    # Print table
    print(separator)
    print(format_table_row(headers, widths))
    print(separator)
    for row in rows:
        print(format_table_row(row, widths))
    print(separator)


def format_speedup(value):
    """Render a speedup ratio as ``"<ratio>x faster"`` or ``"<ratio>x slower"``.

    The ratio is shown with two decimals; only values strictly greater than
    1.0 are labelled "faster".
    """
    if value > 1.0:
        direction = "faster"
    else:
        direction = "slower"
    return "{:.2f}x {}".format(value, direction)


def run_benchmarks(verbose: bool = False):
    """Run benchmarks for a set of common shapes.

    Prints the environment, benchmarks every (m, n, k) shape with
    ``benchmark_shape`` and then prints one performance table per
    implementation followed by average performance, average speedup and
    average accuracy tables. Returns early, after printing a message, when
    CUDA is not available.

    Args:
        verbose: Forwarded to ``benchmark_shape`` to enable per-shape output.
    """
    print("===== STARTING FP8 GEMM BENCHMARK =====")

    # Make sure we're using the GPU
    if not torch.cuda.is_available():
        print("CUDA not available! Tests require GPU.")
        return

    # Print system information
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Triton version: {triton.__version__}")
    print(f"Using device: {torch.cuda.get_device_name()}")

    # Enable TF32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Set seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    # Define benchmark shapes (m, n, k)
    shapes = [
        # (64, 2112, 7168),
        (64, 24576, 1536),
        (64, 32768, 512),
        (64, 7168, 16384),
        (64, 4096, 7168),
        (64, 7168, 2048),
        # (128, 2112, 7168),
        (128, 24576, 1536),
        (128, 32768, 512),
        (128, 7168, 16384),
        (128, 4096, 7168),
        (128, 7168, 2048),
        # (4096, 2112, 7168),
        (4096, 24576, 1536),
        (4096, 32768, 512),
        (4096, 7168, 16384),
        (4096, 4096, 7168),
        (4096, 7168, 2048),
    ]

    all_results = [
        benchmark_shape(m, n, k, verbose=verbose) for m, n, k in shapes
    ]
    num_shapes = len(all_results)

    def perf_rows(impl, speedup_keys=()):
        """Build the table rows of one implementation, one row per shape."""
        rows = []
        for result in all_results:
            shape = result["shape"]
            impl_data = result["implementations"][impl]
            row = [
                shape["m"], shape["n"], shape["k"],
                f"{impl_data['time_us']:.1f}", f"{impl_data['tflops']:.1f}",
                f"{impl_data['gb_s']:.1f}"
            ]
            # A missing speedup entry is reported as a neutral 1.0.
            row.extend(
                format_speedup(impl_data.get(key, 1.0))
                for key in speedup_keys)
            rows.append(row)
        return rows

    # Print results in a nicely formatted table
    print("\n===== PERFORMANCE COMPARISON =====")

    base_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]

    # Print DeepGEMM table
    print_table(base_headers,
                perf_rows("DeepGEMM"),
                title="DeepGEMM Implementation:")

    # Print vLLM Triton table
    print_table(base_headers + ["vs DeepGEMM"],
                perf_rows("vLLM Triton", ("speedup_vs_deepgemm", )),
                title="vLLM Triton Implementation:")

    # Print vLLM CUTLASS table
    print_table(base_headers + ["vs DeepGEMM", "vs Triton"],
                perf_rows("vLLM CUTLASS",
                          ("speedup_vs_deepgemm", "speedup_vs_triton")),
                title="vLLM CUTLASS Implementation:")

    # Calculate and print averages
    print("\n===== AVERAGE PERFORMANCE =====")

    implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
    avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
    avg_rows = []

    for impl in implementations:
        per_shape = [result["implementations"][impl] for result in all_results]
        avg_tflops = sum(data["tflops"] for data in per_shape) / num_shapes
        avg_mem_bw = sum(data["gb_s"] for data in per_shape) / num_shapes
        avg_time = sum(data["time_ms"] for data in per_shape) / num_shapes
        avg_rows.append([
            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
        ])

    print_table(avg_headers, avg_rows)

    # Calculate average speedups. "X vs Y" is time(Y) / time(X), so a value
    # above 1 means X is faster; each entry maps to (Y, X).
    speedup_pairs = {
        "DeepGEMM vs vLLM Triton": ("vLLM Triton", "DeepGEMM"),
        "DeepGEMM vs vLLM CUTLASS": ("vLLM CUTLASS", "DeepGEMM"),
        "vLLM CUTLASS vs vLLM Triton": ("vLLM Triton", "vLLM CUTLASS")
    }

    print("\n===== AVERAGE SPEEDUPS =====")
    speedup_headers = ["Comparison", "Speedup"]
    speedup_rows = []
    for comparison, (slower_ref, faster_ref) in speedup_pairs.items():
        total = sum(result["implementations"][slower_ref]["time_ms"] /
                    result["implementations"][faster_ref]["time_ms"]
                    for result in all_results)
        avg_speedup = total / num_shapes
        status = "faster" if avg_speedup > 1 else "slower"
        speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])

    print_table(speedup_headers, speedup_rows)

    # Average accuracy comparison
    print("\n===== ACCURACY COMPARISON =====")

    diff_headers = ["Implementation", "Avg Diff vs Reference"]
    diff_rows = []
    for impl in implementations:
        total_diff = sum(result["implementations"][impl]["diff"]["Reference"]
                         for result in all_results)
        diff_rows.append([impl, f"{total_diff / num_shapes:.6f}"])

    print_table(diff_headers, diff_rows)


# Script entry point: run the full benchmark suite with per-shape logging off.
if __name__ == "__main__":
    run_benchmarks(verbose=False)