from typing import Tuple

import deep_gemm
import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)

from sglang.srt.layers.quantization.fp8_kernel import (
    w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
)


# Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
def tl_gemm(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
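    """Build a TileLang FP8 block-scaled GEMM kernel computing C = A @ B^T.

    A is (M, K) e4m3 with per-token scales of shape (M, ceil(K / 128));
    B is (N, K) e4m3 with per-128x128-block scales of shape
    (ceil(N / 128), ceil(K / 128)). Partial sums are kept in `accum_dtype`
    and rescaled after every K tile (the 2xAcc promotion scheme).
    """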
    assert in_dtype in [
        "e4m3_float8",
    ], "Currently only e4m3_float8 is supported"
    assert out_dtype in [
        "bfloat16",
        "float16",
    ], "Currently only bfloat16 and float16 are supported"

    TILE_SIZE = (128, 128, 128)
    block_M = TILE_SIZE[0]
    block_N = TILE_SIZE[1]
    block_K = TILE_SIZE[2]

    A_shape = (M, K)
    Scales_A_shape = (M, T.ceildiv(K, block_K))
    B_shape = (N, K)
    Scales_B_shape = (T.ceildiv(N, block_N), T.ceildiv(K, block_K))
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (block_M, block_N)

    @T.prim_func
    def main(
        A: T.Buffer(A_shape, in_dtype),
        scales_a: T.Buffer(Scales_A_shape, "float32"),
        B: T.Buffer(B_shape, in_dtype),
        scales_b: T.Buffer(Scales_B_shape, "float32"),
        C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype)
            Scale_C_shared = T.alloc_shared((block_M,), "float32")
            C_local = T.alloc_fragment(C_shared_shape, accum_dtype)
            C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=4):
                # Load A into shared memory
                T.copy(A[by * block_M, k * block_K], A_shared)
                # Load B into shared memory
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Load scale into shared memory
                Scale_B = scales_b[bx, k]
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Promote to enable 2xAcc
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            # TMA store
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
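    """Quantize each row of `x` to float8_e4m3fn in groups of 128 columns.

    Returns the FP8 tensor of shape (m, n) and per-group scales of shape
    (m, n // 128); 448.0 is the largest finite e4m3 value.
    """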
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n
    ), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
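    """Quantize `x` to float8_e4m3fn in 128x128 blocks.

    `x` is zero-padded to multiples of 128 in both dimensions before the
    per-block amax is computed, then cropped back. Returns the FP8 tensor of
    shape (m, n) and per-block scales of shape (ceil(m / 128), ceil(n / 128)).
    """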
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
        x_view.size(0), x_view.size(2)
    )
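

# Hypothetical helper (not called by the benchmark below): dequantize both
# operands with the scale layouts produced above and run a plain fp32 matmul.
# A minimal sketch for sanity-checking the block-scaling convention; it
# assumes K is a multiple of 128, as enforced by per_token_cast_to_fp8.
def _dequant_reference_gemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
) -> torch.Tensor:
    m, k = x_fp8.shape
    n = y_fp8.shape[0]
    # Expand per-token scales (m, k // 128) to (m, k).
    x = x_fp8.float() * x_scale.repeat_interleave(128, dim=1)[:, :k]
    # Expand per-block scales (ceil(n / 128), ceil(k / 128)) to (n, k).
    y_scale_full = y_scale.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
    y = y_fp8.float() * y_scale_full[:n, :k]
    return (x @ y.t()).to(torch.bfloat16)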


def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """DeepGEMM implementation of FP8 GEMM"""
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

    # Run DeepGEMM kernel
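    # (the callers pass the activation scales through get_mn_major_tma_aligned_tensor,
    # since DeepGEMM expects an MN-major, TMA-aligned scale layout for the LHS)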
    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
    return out


def fp8_gemm_sglang(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """SGLang implementation of FP8 GEMM"""
    block_size = [128, 128]  # Matches the block size in per_block_cast_to_fp8

    # Run SGLang kernel
    out = w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
    )
    return out


def fp8_gemm_vllm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """vLLM implementation of FP8 GEMM"""
    block_size = [128, 128]  # Matches the block size in per_block_cast_to_fp8

    # Run vLLM kernel
    out = vllm_w8a8_block_fp8_matmul(
        x_fp8, y_fp8, x_scale, y_scale, block_size, torch.bfloat16
    )
    return out


def calculate_diff(m: int, n: int, k: int):
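    """Compare DeepGEMM, SGLang, and TileLang FP8 GEMM outputs on random inputs."""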
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
    y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
    x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())

    out_deepgemm = fp8_gemm_deepgemm(
        x_fp8.clone(),
        x_scale_col_major.clone(),
        y_fp8.clone(),
        y_scale.clone(),
        m,
        n,
        k,
    )
    out_sglang = fp8_gemm_sglang(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone(), m, n, k
    )

    tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
    tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
    out_tilelang = tilelang_kernel(
        x_fp8.clone(), x_scale.clone(), y_fp8.clone(), y_scale.clone()
    )

    diff_sglang_deepgemm = torch.abs(out_deepgemm - out_sglang).mean().item()
    diff_tilelang_deepgemm = torch.abs(out_deepgemm - out_tilelang).mean().item()
    diff_tilelang_sglang = torch.abs(out_tilelang - out_sglang).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"SGLang output: {out_sglang[0, 0:5]}")
    print(f"TileLang output: {out_tilelang[0, 0:5]}")
    print(f"Mean absolute difference (SGLang-DeepGEMM): {diff_sglang_deepgemm}")
    print(f"Mean absolute difference (TileLang-DeepGEMM): {diff_tilelang_deepgemm}")
    print(f"Mean absolute difference (TileLang-SGLang): {diff_tilelang_sglang}")

    sglang_deepgemm_match = torch.allclose(
        out_deepgemm, out_sglang, atol=1e-2, rtol=1e-2
    )
    tilelang_deepgemm_match = torch.allclose(
        out_deepgemm, out_tilelang, atol=1e-2, rtol=1e-2
    )
    tilelang_sglang_match = torch.allclose(
        out_tilelang, out_sglang, atol=1e-2, rtol=1e-2
    )

    if sglang_deepgemm_match and tilelang_deepgemm_match and tilelang_sglang_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        print(f"  - SGLang vs DeepGEMM: {'✅' if sglang_deepgemm_match else '❌'}")
        print(f"  - TileLang vs DeepGEMM: {'✅' if tilelang_deepgemm_match else '❌'}")
        print(f"  - TileLang vs SGLang: {'✅' if tilelang_sglang_match else '❌'}\n")


def get_weight_shapes(tp_size):
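    """Return (N, K) weight shapes to benchmark.

    Shapes in `total` cannot be split by tensor parallelism; `n_tp` shapes
    are split along N and `k_tp` shapes along K by `tp_size`.
    """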
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    weight_shapes = []
    for t in total:
        weight_shapes.append(t)
    for n_t in n_tp:
        new_t = (n_t[0] // tp_size, n_t[1])
        weight_shapes.append(new_t)
    for k_t in k_tp:
        new_t = (k_t[0], k_t[1] // tp_size)
        weight_shapes.append(new_t)

    return weight_shapes


def create_benchmark_configs(tp_size):
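    """Cross each weight shape with the batch sizes to form (m, n, k, tp_size) configs."""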
    configs = []
    weight_shapes = get_weight_shapes(tp_size)
    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]

    for n, k in weight_shapes:
        for m in batch_sizes:
            configs.append((m, n, k, tp_size))

    return configs


def get_benchmark(tp_size):
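    """Build a triton perf_report benchmark over all configs for the given TP size."""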
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "tp_size"],
            x_vals=[list(config) for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "sglang", "tilelang"],
            line_names=["DeepGEMM", "SGLang", "TileLang"],
            styles=[("blue", "-"), ("red", "-"), ("green", "-")],
            ylabel="ms",
            plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, tp_size, provider):
        print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

        # Preprocess data before benchmarking
        x_fp8, x_scale = per_token_cast_to_fp8(x)
        y_fp8, y_scale = per_block_cast_to_fp8(y)
        x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())

        quantiles = [0.5, 0.2, 0.8]
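        # do_bench returns the (median, 20th-percentile, 80th-percentile)
        # runtimes in milliseconds for the quantiles requested above.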

        if provider == "deepgemm":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_deepgemm(
                    x_fp8.clone(),
                    x_scale_col_major.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                    m,
                    n,
                    k,
                ),
                quantiles=quantiles,
            )
        elif provider == "sglang":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_sglang(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                    m,
                    n,
                    k,
                ),
                quantiles=quantiles,
            )
        else:  # tilelang
            tilelang_func = tl_gemm(m, n, k, "e4m3_float8", "bfloat16", "float32")
            tilelang_kernel = tilelang.compile(tilelang_func, out_idx=[-1])
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: tilelang_kernel(
                    x_fp8.clone(),
                    x_scale.clone(),
                    y_fp8.clone(),
                    y_scale.clone(),
                ),
                quantiles=quantiles,
            )

        # Calculate TFLOPS
        flops = 2 * m * n * k  # multiply-adds
        tflops = flops / (ms * 1e-3) / 1e12

        # Print shape-specific results with TFLOPS
        print(f"Time: {ms:.3f} ms, TFLOPS: {tflops:.2f}")
        return ms, max_ms, min_ms  # do_bench already reports milliseconds

    return benchmark


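# Example invocation (assumes a CUDA GPU with deep_gemm, tilelang, vllm, and
# sglang installed):
#   python benchmark_deepgemm_fp8_gemm.py --tp_size 1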
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_gemm/",
        help="Path to save fp8 gemm benchmark results",
    )
    parser.add_argument(
        "--run_correctness",
        action="store_true",
        default=True,
        help="Whether to run correctness test",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Run correctness tests on a few examples
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(64, 512, 7168)  # Small test
        calculate_diff(64, 7168, 16384)  # Medium test
        calculate_diff(64, 18432, 7168)  # Large test

    # Get the benchmark function with the specified tp_size
    benchmark = get_benchmark(args.tp_size)

    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)