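"""Benchmark FP8 scaled GEMM kernels.

Compares sgl_kernel's fp8_scaled_mm against vLLM's cutlass_scaled_mm (when vLLM is
installed) across weight shapes taken from several open-weight models.
"""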
import argparse
import copy
import itertools
import os
from typing import Optional, Tuple

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8

# Optional vLLM import
try:
    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
    from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

    VLLM_AVAILABLE = True
except ImportError:
    vllm_scaled_mm = None
    vllm_scaled_fp8_quant = None
    VLLM_AVAILABLE = False

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
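    """Per-tensor FP8 (e4m3) quantization via sgl_kernel.

    If `scale` is given it is used as a static scale; otherwise a fresh scale tensor
    is allocated and the kernel is expected to fill it (dynamic quantization).
    Returns the quantized tensor and the scale.
    """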
    fp8_type_: torch.dtype = torch.float8_e4m3fn
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)

    return output, scale


# CI environment uses simplified parameters
if IS_CI:
    batch_sizes = [1]  # Single batch size for CI
else:
    batch_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048]

# Filter line_vals based on vLLM availability
if VLLM_AVAILABLE:
    line_vals = [
        "vllm-fp8-fp16",
        "vllm-fp8-bf16",
        "sglang-fp8-fp16",
        "sglang-fp8-bf16",
    ]
    line_names = [
        "vllm-fp8-fp16",
        "vllm-fp8-bf16",
        "sglang-fp8-fp16",
        "sglang-fp8-bf16",
    ]
    styles = [("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")]
else:
    line_vals = [
        "sglang-fp8-fp16",
        "sglang-fp8-bf16",
    ]
    line_names = [
        "sglang-fp8-fp16",
        "sglang-fp8-bf16",
    ]
    styles = [("blue", "-"), ("blue", "--")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=batch_sizes,
        x_log=False,
        line_arg="provider",
        line_vals=line_vals,
        line_names=line_names,
        styles=styles,
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
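    """Time one FP8 scaled GEMM of shape (M=batch_size, K) x (K, N) for `provider`."""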
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    # vLLM expects scalar scales, while sglang can handle per-token scales
    scale_a_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
    scale_b_scalar = torch.randn(1, device="cuda", dtype=torch.float32)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
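    # Scales are random placeholders: the benchmark measures kernel time only and
    # does not validate the numerical output of the quantized GEMM.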
    quantiles = [0.5, 0.2, 0.8]

    dtype = torch.float16 if "fp16" in provider else torch.bfloat16

    if "vllm-fp8" in provider:
        if not VLLM_AVAILABLE:
            # Return zero if vLLM is not available
            return (0, 0, 0)
        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_scalar)
        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b_scalar)
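        # Quantization is done on the (N, K) weight; the transpose exposes it as a
        # column-major (K, N) operand for cutlass_scaled_mm.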
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
        b_fp8 = b_fp8.t()
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: sgl_scaled_mm(
                a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
            ),
            quantiles=quantiles,
        )

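    # Throughput proxy for the GB/s axis: GEMM work (2*M*N*K multiply-adds plus the
    # M*N output elements) scaled by the element size of the unquantized input.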
    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
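    """Expand (model, tp_size) pairs into [K, N, model_name] entries, dividing the
    TP-split dimension of each weight shape by the tensor-parallel size."""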
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
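    # Example invocation (requires a CUDA GPU with FP8 support):
    #   python bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1 2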
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # Simplify for CI environment
    if IS_CI:
        args.models = [args.models[0]]  # Use only first model
        args.tp_sizes = [args.tp_sizes[0]]  # Use only first TP size

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(print_data=True, N=N, K=K)

    print("Benchmark finished!")