import argparse
import copy
import itertools
import os

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm

# Optional DeepGEMM import; the "deepgemm" provider is only added later if this succeeds
try:
    import deep_gemm
    from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
except ImportError:
    deep_gemm = None
    get_mn_major_tma_aligned_tensor = None

# Optional vLLM import
try:
    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm

    VLLM_AVAILABLE = True
except ImportError:
    vllm_scaled_mm = None
    VLLM_AVAILABLE = False

from sglang.srt.layers.quantization.fp8_kernel import (
    w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
)

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)


def get_weight_shapes(args):
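    """Return (N, K, model_name) triples for the configured models and TP sizes."""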
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): These weight shapes only work for DeepSeek-V3. Modify them if you tune for a different model.
    # shapes that cannot be TP-sharded
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can be TP-sharded
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can be TP-sharded
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only DeepSeek-V3 is supported
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        for t in total:
            new_t = [t[0], t[1], model]
            weight_shapes.append(new_t)
        for n_t in n_tp:
            new_t = [n_t[0] // tp_size, n_t[1], model]
            weight_shapes.append(new_t)
        for k_t in k_tp:
            new_t = [k_t[0], k_t[1] // tp_size, model]
            weight_shapes.append(new_t)
    return weight_shapes


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def fp8_gemm_deepgemm(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
    m: int,
    n: int,
    k: int,
):
    """DeepGEMM implementation of FP8 GEMM"""
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)

    # Run DeepGEMM kernel
    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
    return out


def scale_shape(shape, group_shape):
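    """Number of scale groups per dimension: ceil(shape[i] / group_shape[i])."""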
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


# CI environment uses simplified parameters
if IS_CI:
    batch_sizes = [1, 8]  # Simplified for CI
else:
    batch_sizes = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

# Filter providers based on availability
available_providers = ["sgl-kernel"]
available_names = ["sgl-kernel"]
available_styles = [("orange", "-")]

if VLLM_AVAILABLE:
    available_providers.insert(0, "vllm")
    available_names.insert(0, "vllm")
    available_styles.insert(0, ("blue", "-"))

available_providers.append("triton")
available_names.append("sglang triton")
available_styles.append(("red", "-"))

# Add deepgemm if available
try:
    import deep_gemm

    available_providers.append("deepgemm")
    available_names.append("deepgemm")
    available_styles.append(("yellow", "-"))
except ImportError:
    pass


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=batch_sizes,
        x_log=False,
        line_arg="provider",
        line_vals=available_providers,
        line_names=available_names,
        styles=available_styles,
        ylabel="us",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
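    """Time one FP8 blockwise GEMM with A of shape (M=batch_size, K) and B of shape (N, K).

    Returns (median, 0.8-quantile, 0.2-quantile) latencies in microseconds for the given provider.
    """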
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
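    # One scale per 1x128 activation group (per token, per 128 channels) and one per
    # 128x128 weight block; random scale values are fine since only latency is measured.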

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        scale_a = scale_a.t().contiguous().t()
        b_fp8, scale_b = b_fp8.t(), scale_b.t()
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        scale_a = scale_a.t().contiguous().t()
        b_fp8, scale_b = b_fp8.t(), scale_b.t()
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: w8a8_block_fp8_matmul(
                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
            ),
            quantiles=quantiles,
        )
    elif provider == "deepgemm":
        scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: fp8_gemm_deepgemm(
                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
            ),
            quantiles=quantiles,
        )
    return ms * 1000, max_ms * 1000, min_ms * 1000  # convert ms to us


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # Simplify for CI environment
    if IS_CI:
        args.models = [args.models[0]]  # Use only first model
        args.tp_sizes = [args.tp_sizes[0]]  # Use only first TP size

    NK_model_names = get_weight_shapes(args)

    # Limit iterations in CI
    if IS_CI:
        NK_model_names = NK_model_names[:2]  # Only test first 2 shapes in CI

    for N, K, model_name in NK_model_names:
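        # the blockwise kernels use 1x128 / 128x128 scale groups, so only 128-aligned shapes are benchmarked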
        if N % 128 != 0 or K % 128 != 0:
            print(f"Skipping {N=}, {K=}: not divisible by 128")
            continue
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            N=N,
            K=K,
        )

    print("Benchmark finished!")