from typing import Tuple

import deep_gemm
import torch
import triton
import triton.language as tl
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor

# Import shared functionality from the regular GEMM benchmark
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
    per_block_cast_to_fp8,
    per_token_cast_to_fp8,
)
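# The imported helpers quantize to FP8 with block-wise scales. Shapes, as
# inferred from how they are used below (and assuming they mirror the
# DeepGEMM test helpers):
#   per_token_cast_to_fp8(x): (m, k) bf16 -> ((m, k) fp8, (m, k // 128) fp32 scales), one scale per 1x128 tile
#   per_block_cast_to_fp8(y): (n, k) bf16 -> ((n, k) fp8, (ceil(n/128), ceil(k/128)) fp32 scales), one scale per 128x128 tile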


def construct_grouped_and_flat_fp8(
    x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
) -> Tuple[
    Tuple[torch.Tensor, torch.Tensor],  # grouped x_fp8
    Tuple[torch.Tensor, torch.Tensor],  # grouped y_fp8
    Tuple[torch.Tensor, torch.Tensor],  # flat x_fp8
    Tuple[torch.Tensor, torch.Tensor],  # flat y_fp8
    torch.Tensor,  # output
    torch.Tensor,  # reference output
]:
    # Verify input shapes
    m, k = x.shape
    n, k_y = y.shape
    assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
    assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
    assert m % 4 == 0, f"TMA alignment error: {m}"

    # Reshape inputs for grouped processing
    m_per_group = m // num_groups
    x_grouped = x.view(num_groups, m_per_group, k)
    y_grouped = y.unsqueeze(0).expand(num_groups, n, k)

    # Initialize output tensors
    out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16)
    ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped)

    # Quantize grouped tensors
    x_fp8_grouped = (
        torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float
        ),
    )
    y_fp8_grouped = (
        torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
        ),
    )
    for i in range(num_groups):
        x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i])
        y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i])

    # Quantize flat tensors
    x_fp8_flat = per_token_cast_to_fp8(x)
    y_fp8_flat = per_block_cast_to_fp8(y)

    # For non-masked input, merge the group and M dims (for x and both outputs)
    if not is_masked:
        x_fp8_grouped = (
            x_fp8_grouped[0].view(-1, k),
            per_token_cast_to_fp8(x_grouped.view(-1, k))[1],
        )
        out, ref_out = out.view(-1, n), ref_out.view(-1, n)

    # Pre-convert the per-token scales to the column-major TMA-aligned layout now,
    # so the timed runs below do not trigger transpose kernels
    x_fp8_grouped = (
        x_fp8_grouped[0],
        get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
    )
    x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))

    return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out


# Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a
# custom kernel based on the Triton tutorial.
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
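# The kernel dequantizes block-wise: for each K block, the FP8 partial product
# is multiplied by A's per-(1x128)-tile scale and B's per-(128x128)-tile scale
# before being added to the FP32 accumulator.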
@triton.jit
def fp8_gemm_group_triton_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Pointers to scaling factors
    a_scale_ptr,
    b_scale_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension.
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Strides for scaling factors
    stride_a_scale_m,
    stride_a_scale_k,
    stride_b_scale_n,
    stride_b_scale_k,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)

    Note: Block sizes must be multiples of 32 for optimal TMA performance.
    """
    # Map program ids to the block of C it should compute
    pid_group = tl.program_id(axis=0)  # Group ID
    pid_n = tl.program_id(axis=1)  # N dimension ID

    # Compute the M block ID within this group
    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m_within_group = tl.program_id(axis=2) % group_size_m
    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group

    # Create pointers for the first blocks of A and B
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Initialize accumulator
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Main loop
    for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_offset = k_block * BLOCK_SIZE_K

        # Load the next block of A and B, with masks
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)

        # Calculate indices for scaling factors for this K block
        a_scale_ptrs = a_scale_ptr + (
            offs_am * stride_a_scale_m + k_block * stride_a_scale_k
        )
        b_scale_ptrs = b_scale_ptr + (
            pid_n * stride_b_scale_n + k_block * stride_b_scale_k
        )

        # Perform matrix multiplication in FP8
        res = tl.dot(a, b)

        # Load scaling factors for the current block
        a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
        b_scale = tl.load(b_scale_ptrs)

        # Apply scaling factors to the accumulated result
        accumulator += res * a_scale * b_scale

        # Advance pointers
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Convert to bfloat16 for output
    c = accumulator.to(tl.bfloat16)

    # Write back the result
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
    """
    Perform matrix multiplication with FP8 inputs and proper scaling.

    Args:
        a_tuple: Tuple of (quantized_tensor, scale_factors) for input A
        b_tuple: Tuple of (quantized_tensor, scale_factors) for input B
        c: Output tensor in BF16 format
        num_groups: Number of groups for grouped GEMM

    Returns:
        Result tensor in BF16 format
    """
    # Unpack the tuples
    a, a_scale = a_tuple
    b, b_scale = b_tuple

    M, K = a.shape
    _, N = b.shape

    # Configure block sizes - must be multiples of 32 for TMA alignment
    BLOCK_SIZE_M = 128
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
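    # 128 matches the quantization tile width, so each K block consumes exactly
    # one a_scale entry per row and one b_scale entry per (N block, K block).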

    # Calculate grid dimensions
    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
    num_groups_grid = triton.cdiv(num_pid_m, num_groups)

    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))

    fp8_gemm_group_triton_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        1,  # a_scale is made contiguous by the callers, so its K-block stride is 1
        b_scale.stride(0),
        1 if b_scale.dim() > 1 else 0,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=num_groups,
    )

    return c


def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
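    # DeepGEMM's contiguous-layout grouped GEMM: x_fp8_grouped is an (fp8, scales)
    # pair with all groups flattened into the M dimension, y_fp8_grouped keeps its
    # leading num_groups dimension, and m_indices[i] gives the group index for row i
    # of the flattened x (see its construction in calculate_diff below).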
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
        x_fp8_grouped,
        y_fp8_grouped,
        out,
        m_indices,
    )
    return out


def calculate_diff(m: int, n: int, k: int, num_groups: int):
    print(f"Shape (m={m}, n={n}, k={k}")
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
    x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
        construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
    )
    m_per_group = m // num_groups
    out_deepgemm = out.clone()
    m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
    m_indices = (
        m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
    )
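    # e.g. num_groups=2, m_per_group=3 -> m_indices = [0, 0, 0, 1, 1, 1]:
    # each row of the flattened x is tagged with the group it belongs to.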

    fp8_gemm_group_deepgemm(
        x_fp8_grouped,
        y_fp8_grouped,
        out_deepgemm,
        m_indices,
    )
    torch.cuda.synchronize()

    # Prepare inputs for Triton
    a, a_scale = x_fp8_flat
    b, b_scale = y_fp8_flat
    b = b.T.contiguous()
    # Ensure scales are in the right format and contiguous
    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
    M, _ = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
    torch.cuda.synchronize()

    diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
    diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
    diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"Torch output: {out_torch[0, 0:5]}")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"Triton output: {out_triton[0, 0:5]}")
    print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
    print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
    print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")

    deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
    triton_torch_diff = calc_diff(out_triton, out_torch)
    deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)

    DIFF_THRESHOLD = 0.001
    all_match = (
        deepgemm_torch_diff < DIFF_THRESHOLD
        and triton_torch_diff < DIFF_THRESHOLD
        and deepgemm_triton_diff < DIFF_THRESHOLD
    )
    if all_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        print(
            f"  - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}\n"
            f"  - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}\n"
            f"  - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}"
        )


def get_weight_shapes(tp_size):
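    # These (N, K) pairs appear to be DeepSeek-V3/R1 projection shapes (an
    # assumption); the first list cannot be sharded under tensor parallelism,
    # while the next two lists are sharded along N or K by tp_size.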
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    weight_shapes = []
    for t in total:
        weight_shapes.append(t)
    for n_t in n_tp:
        new_t = (n_t[0] // tp_size, n_t[1])
        weight_shapes.append(new_t)
    for k_t in k_tp:
        new_t = (k_t[0], k_t[1] // tp_size)
        weight_shapes.append(new_t)

    return weight_shapes


def create_benchmark_configs(tp_size):
    configs = []
    weight_shapes = get_weight_shapes(tp_size)
    batch_sizes = [2048, 4096]
    group_sizes = [4, 8]
    for n, k in weight_shapes:
        for m in batch_sizes:
            for num_groups in group_sizes:
                configs.append((m, n, k, num_groups, tp_size))

    return configs


def get_benchmark(tp_size):
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "num_groups", "tp_size"],
            x_vals=[config for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "triton"],
            line_names=["DeepGEMM", "Triton"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="ms",
            plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, num_groups, tp_size, provider):
        print(
            f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}"
        )
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
        x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
            construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
        )
        m_per_group = m // num_groups
        m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
        m_indices = (
            m_indices.unsqueeze(-1)
            .expand(num_groups, m_per_group)
            .contiguous()
            .view(-1)
        )

        quantiles = [0.5, 0.2, 0.8]

        if provider == "deepgemm":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_group_deepgemm(
                    x_fp8_grouped,
                    y_fp8_grouped,
                    out,
                    m_indices,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton":
            # Prepare inputs for Triton
            # Done outside the timed lambda so the comparison with DeepGEMM is fair
            a, a_scale = x_fp8_flat
            b, b_scale = y_fp8_flat
            b = b.T.contiguous()
            # Ensure scales are in the right format and contiguous
            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
            M, _ = a.shape
            _, N = b.shape
            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_group_triton(
                    (a, a_scale),
                    (b, b_scale),
                    c,
                    num_groups,
                ),
                quantiles=quantiles,
            )

        # Calculate TFLOPS
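        # Each of the num_groups GEMMs is (m / num_groups, k) x (k, n), so the
        # total is num_groups * 2 * (m / num_groups) * n * k = 2 * m * n * k.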
        flops = 2 * m * n * k  # 2 FLOPs (multiply + add) per MAC
        tflops = flops / (ms * 1e-3) / 1e12

        print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}")
        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms

    return benchmark


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_group_gemm/",
        help="Path to save deepgemm fp8 group gemm benchmark results",
    )
    parser.add_argument(
        "--run_correctness",
        action="store_true",
        help="Whether to run correctness test",
    )
    parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
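    # Example invocations (flags as defined above):
    #   python benchmark_deepgemm_fp8_group_gemm.py --run_correctness
    #   python benchmark_deepgemm_fp8_group_gemm.py --tp_size 8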
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Run correctness tests on a few examples
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(8192, 7168, 4096, 4)
        calculate_diff(8192, 2048, 7168, 4)
        calculate_diff(4096, 7168, 4096, 8)
        calculate_diff(4096, 2048, 7168, 8)
        calculate_diff(4096, 576, 7168, 8)

    # Get the benchmark function with the specified tp_size
    benchmark = get_benchmark(args.tp_size)

    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)