# bench_per_token_quant_fp8.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
5
from unittest.mock import patch
6

7
import pandas as pd
8
9
10
11
12
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
13
14
15
16
17
18
19
20
21
22
23
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


def with_triton_mode(fn):
    """Wrap ``fn`` so every call runs with the Triton fallback path forced.

    For the duration of each call, ``vllm.platforms.current_platform.is_cuda``
    is patched to return ``False``; the patch is scoped by a ``with`` block,
    so it is removed again when the call returns or raises.

    Args:
        fn: Callable to wrap.

    Returns:
        A wrapper accepting the same positional/keyword arguments as ``fn``
        and carrying its metadata (``__name__``, ``__doc__``, ``__wrapped__``).
    """
    import functools

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
            return fn(*args, **kwargs)

    return wrapped
24
25
26
27
28
29
30
31
32
33
34


# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


35
36
37
def bench_compile(fn: Callable):
    """Compile ``fn`` for benchmarking with only its leading dimension dynamic.

    ``fn`` is compiled with ``fullgraph=True`` and ``dynamic=False``, so shape
    changes trigger recompiles. The returned wrapper (built with
    ``with_dyn_arg``) marks dimension 0 of the first positional argument as
    dynamic before each call, to simulate vLLM usage.

    Args:
        fn: Function to compile.

    Returns:
        A callable wrapping the compiled ``fn``.
    """
    # recompile for different shapes
    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

    # First dim is explicitly dynamic to simulate vLLM usage
    return with_dyn_arg(fwd, 0, 0)
41
42


43
torch._dynamo.config.recompile_limit = 8888
44
45


46
47
48
49
50
51
52
def calculate_diff(
    batch_size: int,
    hidden_size: int,
    group_shape: GroupShape,
    dtype: torch.dtype,
):
    """Compare the native (compiled and eager) QuantFP8 paths against CUDA.

    Quantizes one random ``(batch_size, hidden_size)`` CUDA tensor with
    ``forward_native`` under ``torch.compile``, with ``forward_native``
    eagerly, and with ``forward_cuda``. It then checks that both native
    results (quantized values and scales) match the CUDA result. Every
    comparison is run, and each mismatch is reported with the pair it
    concerns. Results are printed; nothing is returned.

    Args:
        batch_size: Number of rows of the random input.
        hidden_size: Number of columns of the random input.
        group_shape: Quantization group shape passed to ``QuantFP8``.
        dtype: Dtype of the random input.
    """
    device = torch.device("cuda")
    x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)

    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

    # Quantized values are compared after upcasting to float32; scales are
    # compared as returned.
    cuda_out_f32 = cuda_out.to(torch.float32)
    comparisons = [
        ("compiled native output", cuda_out_f32, torch_out.to(torch.float32)),
        ("compiled native scale", cuda_scale, torch_scale),
        ("eager native output", cuda_out_f32, torch_eager_out.to(torch.float32)),
        ("eager native scale", cuda_scale, torch_eager_scale),
    ]

    failures = []
    for name, cuda_val, ref_val in comparisons:
        try:
            torch.testing.assert_close(cuda_val, ref_val, rtol=1e-3, atol=1e-5)
        except AssertionError as e:
            failures.append((name, e))

    if not failures:
        print("✅ All implementations match")
        return

    print("❌ Implementations differ")
    for name, err in failures:
        print(f"CUDA vs {name}:")
        print(err)
81
82


83
# Benchmark x-values, populated by the ``__main__`` section; each entry is a
# (hidden_size, batch_size, col_major, group_shape) tuple.
configs: list[tuple] = []
84
85


86
87
88
89
90
91
92
93
def benchmark_quantization(
    batch_size,
    hidden_size,
    provider,
    group_shape: GroupShape,
    col_major: bool,
    dtype: torch.dtype,
):
    """Time one QuantFP8 implementation for a single benchmark configuration.

    Args:
        batch_size: Number of rows of the random CUDA input.
        hidden_size: Number of columns of the random CUDA input.
        provider: ``"torch"`` (compiled ``forward_native``), ``"cuda"``
            (``forward_cuda``) or ``"triton"`` (``forward_cuda`` with the
            Triton fallback forced).
        group_shape: Quantization group shape passed to ``QuantFP8``.
        col_major: Whether ``QuantFP8`` uses column-major scales.
        dtype: Dtype of the random input.

    Returns:
        ``(median, low, high)`` in microseconds, where low/high are the 20th
        and 80th percentiles. This is the ``(y, y_min, y_max)`` order that
        ``triton.testing.perf_report`` unpacks. Unsupported combinations
        return NaNs rather than zeros, so they cannot show up as infinite
        speedups downstream.

    Raises:
        ValueError: If ``provider`` is not one of the values above.
    """
    device = torch.device("cuda")
    x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

    # Build each implementation once, outside the timed callable, so a call
    # only does the input clone plus the quantization itself.
    if provider == "torch":
        impl = bench_compile(quant_fp8.forward_native)
    elif provider == "cuda":
        impl = quant_fp8.forward_cuda
    elif provider == "triton":
        if not group_shape.is_per_group():
            # Triton only supported for per-group
            nan = float("nan")
            return nan, nan, nan
        impl = with_triton_mode(quant_fp8.forward_cuda)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    def fn():
        return impl(x.clone())

    # do_bench_cudagraph returns one value per requested quantile, in order.
    median_ms, low_ms, high_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=quantiles
    )

    return 1000 * median_ms, 1000 * low_ms, 1000 * high_ms


117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# TODO(luka) extract to utils
def compute_geomean_speedups(
    df: pd.DataFrame,
    baseline_col: str,
    speedup_cols: list[str],
    groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Compute geometric mean speedups over a baseline column.

    Args:
        df: Input dataframe
        baseline_col: Column to use as baseline
        speedup_cols: Columns to compute speedups for
        groupby_cols: Columns to group by. If None, compute over entire df.

    Returns:
        pd.DataFrame with geometric mean speedups
    """
    from scipy.stats import gmean

    def geo_speedup(group: pd.DataFrame) -> pd.Series:
        ratios = {
            col: (group[baseline_col] / group[col]).values for col in speedup_cols
        }
        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

    if groupby_cols is None:
        result = geo_speedup(df).to_frame().T
    else:
        result = (
            df.groupby(groupby_cols)
            .apply(geo_speedup, include_groups=False)
            .reset_index()
        )

    return result


156
def _parse_group_shapes(group_sizes: list[int] | None) -> list[GroupShape]:
    """Translate ``--group-sizes`` values into ``GroupShape`` objects.

    ``0`` maps to ``GroupShape.PER_TENSOR``, ``-1`` to ``GroupShape.PER_TOKEN``
    and any positive ``N`` to ``GroupShape(1, N)``. ``None`` (flag omitted)
    selects the default sweep: per-tensor, per-token, (1, 64) and (1, 128).

    Raises:
        ValueError: For sizes below -1, which the flag does not define.
    """
    if group_sizes is None:
        return [
            GroupShape.PER_TENSOR,
            GroupShape.PER_TOKEN,
            GroupShape(1, 64),
            GroupShape(1, 128),
        ]

    shapes = []
    for size in group_sizes:
        if size == 0:
            shapes.append(GroupShape.PER_TENSOR)
        elif size == -1:
            shapes.append(GroupShape.PER_TOKEN)
        elif size > 0:
            shapes.append(GroupShape(1, size))
        else:
            raise ValueError(
                f"Invalid group size {size}: use 0 (per-tensor), "
                "-1 (per-token) or a positive group size"
            )
    return shapes


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
    )
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
    )
    parser.add_argument(
        "--hidden-sizes",
        type=int,
        nargs="+",
        default=[896, 1024, 2048, 4096, 7168],
        help="Hidden sizes to benchmark",
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 16, 128, 512, 1024],
        help="Batch sizes to benchmark",
    )
    parser.add_argument(
        "--group-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Group sizes for GroupShape(1,N) to benchmark. "
        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
    )
    parser.add_argument(
        "--no-column-major",
        action="store_true",
        help="Disable column-major scales testing",
    )

    args = parser.parse_args()

    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
    hidden_sizes = args.hidden_sizes
    batch_sizes = args.batch_sizes

    try:
        group_shapes = _parse_group_shapes(args.group_sizes)
    except ValueError as e:
        # Surface bad --group-sizes values as a CLI usage error (exits).
        parser.error(str(e))

    column_major_scales = [False] if args.no_column_major else [True, False]

    config_gen = itertools.product(
        group_shapes,
        column_major_scales,
        batch_sizes,
        hidden_sizes,
    )

    # Column-major scales are only benchmarked for per-group shapes. Each
    # (group_shape, col_major, batch_size, hidden_size) tuple is reversed to
    # line up with the x_names order used below.
    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))

    print(f"Running {len(configs)} configurations:")
    print(f"  Hidden sizes: {hidden_sizes}")
    print(f"  Batch sizes: {batch_sizes}")
    print(f"  Group shapes: {[str(g) for g in group_shapes]}")
    print(f"  Column major scales: {column_major_scales}")
    print()

    if args.check:
        for group_shape in group_shapes:
            # Print the whole shape rather than only its second element.
            print(f"{group_shape=}")
            calculate_diff(
                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
            )

    benchmark = triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["torch", "cuda", "triton"],
            line_names=["Torch (Compiled)", "CUDA", "Triton"],
            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
            ylabel="us",
            plot_name="QuantFP8 performance",
            args={},
        )
    )(benchmark_quantization)

    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)

    # Geometric-mean speedup over the compiled-torch baseline, computed
    # separately for each (col_major, group_shape) combination.
    geo_table_grouped = compute_geomean_speedups(
        df,
        baseline_col="Torch (Compiled)",
        speedup_cols=["CUDA", "Triton"],
        groupby_cols=["col_major", "group_shape"],
    )

    print("Speedup over Torch (Compiled)")
    print(geo_table_grouped.to_string(index=False))