# bench_nvfp4_quant.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton
from vllm.utils.flashinfer import flashinfer_fp4_quantize

# Fail fast at import time: the script refuses to run unless the current
# platform reports device capability 100 (compute capability 10.0).
if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")

# Maximum values of the FP4 (E2M1) and FP8 (E4M3) types; their product is the
# numerator of the global scale computed in compute_global_scale().
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# One entry per benchmarked line:
#   backend               -- implementation to call ("vllm" or "flashinfer")
#   is_sf_swizzled_layout -- forwarded unchanged to the quantization op
#   enabled               -- set to False to leave the provider out of the run
PROVIDER_CFGS = {
    "vllm": dict(backend="vllm", is_sf_swizzled_layout=False, enabled=True),
    "vllm-swizzle": dict(backend="vllm", is_sf_swizzled_layout=True, enabled=True),
    "flashinfer": dict(backend="flashinfer", is_sf_swizzled_layout=False, enabled=True),
    "flashinfer-swizzle": dict(
        backend="flashinfer", is_sf_swizzled_layout=True, enabled=True
    ),
}

# Names of the enabled providers, in declaration order; used both as the line
# values and the line labels of the benchmark report.
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
    """Return the global scale used for FP4 quantization of ``tensor``.

    The result is ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX`` divided by the largest
    absolute value found in ``tensor``; that maximum is cast to float32 before
    the division.
    """
    largest_magnitude = tensor.abs().max().to(torch.float32)
    numerator = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX
    return numerator / largest_magnitude


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="us (lower is better)",
        plot_name="NVFP4 Input Quantization Latency (us)",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time NVFP4 quantization of a random ``(batch_size, K)`` bfloat16 tensor.

    Args:
        batch_size: Number of rows of the input tensor.
        provider: Key into ``PROVIDER_CFGS``; selects the backend and the
            ``is_sf_swizzled_layout`` flag passed to it.
        N: Not used by this function; accepted because ``benchmark.run`` is
            invoked with both ``N`` and ``K``.
        K: Number of columns of the input tensor.

    Returns:
        A ``(median, q20, q80)`` tuple of latencies in microseconds, i.e. in
        the (value, min, max) order that ``triton.testing.perf_report``
        unpacks.

    Raises:
        KeyError: If ``provider`` is not a key of ``PROVIDER_CFGS``.
        ValueError: If the provider's backend is neither ``"vllm"`` nor
            ``"flashinfer"``.
    """
    cfg = PROVIDER_CFGS[provider]
    backend = cfg["backend"]
    if backend == "vllm":
        # vLLM's FP4 quantization
        quantize = ops.scaled_fp4_quant
    elif backend == "flashinfer":
        # FlashInfer's FP4 quantization
        quantize = flashinfer_fp4_quantize
    else:
        raise ValueError(f"Unknown backend {backend!r} for provider {provider!r}")
    is_sf_swizzled_layout = cfg["is_sf_swizzled_layout"]

    # Create the input tensor and its global scale once, outside the timed call.
    a = torch.randn((batch_size, K), device="cuda", dtype=torch.bfloat16)
    a_global_scale = compute_global_scale(a)

    # Results come back in the order the quantiles are requested:
    # median first, then the 0.2 and 0.8 quantiles.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        lambda: quantize(
            a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
        ),
        quantiles=[0.5, 0.2, 0.8],
    )

    # Convert ms to us for better readability at small batch sizes. The metric
    # is a latency, so the low quantile is the "min" and the high one the "max".
    us_per_ms = 1000
    return ms * us_per_ms, min_ms * us_per_ms, max_ms * us_per_ms


def prepare_shapes(args):
    """Collect the weight shapes to benchmark for every model / TP-size pair.

    For each combination of ``args.models`` and ``args.tp_sizes``, every
    ``(shape, tp_dim)`` entry of ``WEIGHT_SHAPES[model]`` is copied, the
    dimension at index ``tp_dim`` is floor-divided by the TP size, and the
    model name is appended to the shape list.

    Returns:
        A list of those extended shape lists, in iteration order.
    """
    shapes = []
    for model_name, tp_size in itertools.product(args.models, args.tp_sizes):
        # Deep-copy first so the shared WEIGHT_SHAPES table is never mutated.
        model_shapes = copy.deepcopy(WEIGHT_SHAPES[model_name])
        for dims, tp_dim in model_shapes:
            dims[tp_dim] = dims[tp_dim] // tp_size
            dims.append(model_name)
            shapes.append(dims)
    return shapes


def _test_accuracy_once(
    M: int, K: int, dtype: torch.dtype, device: str, is_sf_swizzled_layout: bool
):
    """Check that vLLM and FlashInfer FP4 quantization agree on one input.

    A random ``(M, K)`` tensor is quantized by both implementations with the
    same global scale and the same scale-factor layout flag, and the packed
    outputs and the scales are compared with ``torch.testing.assert_close``.

    Args:
        M: Number of rows of the random input.
        K: Number of columns of the random input.
        dtype: Dtype of the random input.
        device: Device on which the input is created.
        is_sf_swizzled_layout: Layout flag forwarded to both implementations.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when the two
            implementations disagree.
    """
    # Create input tensor
    a = torch.randn((M, K), device=device, dtype=dtype)

    # Compute global scale
    a_global_scale = compute_global_scale(a)

    # vLLM quantization
    vllm_fp4, vllm_scale = ops.scaled_fp4_quant(
        a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
    )

    # FlashInfer quantization, using the same scale-factor layout as vLLM
    flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
        a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
    )
    # Reinterpret FlashInfer's scales as float8_e4m3fn before comparing.
    # NOTE(review): presumably FlashInfer returns them in a raw byte dtype and
    # vLLM already returns float8_e4m3fn -- confirm against both ops.
    flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)

    # Compare outputs
    torch.testing.assert_close(
        vllm_fp4,
        flashinfer_fp4,
    )
    # Compare scales
    torch.testing.assert_close(
        vllm_scale,
        flashinfer_scale,
    )
    print(
        f"M={M}, K={K}, dtype={dtype}, is_sf_swizzled_layout={is_sf_swizzled_layout}: PASSED"  # noqa: E501
    )


def test_accuracy():
    """Run the vLLM-vs-FlashInfer accuracy check over a small shape grid.

    Every combination of scale-factor layout (swizzled first, then linear),
    batch size ``M`` and hidden size ``K`` is checked with bfloat16 inputs on
    the ``"cuda"`` device. The first mismatch propagates out of
    ``_test_accuracy_once``, so the final message is only printed when every
    case passed.
    """
    print("\n" + "=" * 60)
    print("Running accuracy tests: vLLM vs FlashInfer")
    print("=" * 60)

    device = "cuda"
    dtype = torch.bfloat16

    # Test various batch sizes and hidden dimensions
    Ms = [1, 1024]
    Ks = [4096]

    # product() varies the last iterable fastest: layout outermost, K innermost.
    for is_sf_swizzled_layout, M, K in itertools.product([True, False], Ms, Ks):
        _test_accuracy_once(M, K, dtype, device, is_sf_swizzled_layout)

    print("\nAll accuracy tests passed!")


if __name__ == "__main__":
    # Command-line entry point: optionally run the accuracy checks, then
    # benchmark every weight shape of the selected models.
    parser = argparse.ArgumentParser(
        description="Benchmark NVFP4 quantization: vLLM vs FlashInfer"
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
        help="Models whose weight shapes are benchmarked",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="Tensor-parallel sizes used to divide the sharded dimension",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="Run accuracy tests",
    )
    args = parser.parse_args()

    # The accuracy tests are opt-in and run first; the benchmark always runs.
    if args.accuracy:
        test_accuracy()

    for K, N, model in prepare_shapes(args):
        print(f"\n{model}, N={N} K={K}")
        benchmark.run(
            print_data=True,
            save_path=args.save_path,
            N=N,
            K=K,
        )

    print("\nBenchmark finished!")