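"""Benchmark per-token dynamic FP8 quantization.

Compares three implementations on identical inputs: a pure PyTorch
reference, the optional vLLM kernel (when vLLM is installed), and the
SGLang `sgl_per_token_quant_fp8` kernel, first for numerical agreement
and then for latency via Triton's benchmarking utilities.

Usage:
    python bench_per_token_quant_fp8.py
"""
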
import itertools
import os
from typing import Tuple

import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8

# Optional vLLM import
try:
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
    ops = None
    VLLM_AVAILABLE = False

from sglang.srt.utils import is_hip

_is_hip = is_hip()

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)

fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

# Get correct FP8 E4M3 maximum value
if _is_hip:
    FP8_E4M3_MAX = 224.0  # ROCm uses 224.0
else:
    # For CUDA, get the actual max value from the type
    FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)


def torch_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch reference implementation for per-token FP8 quantization."""
    device = input.device
    dtype = input.dtype

    # Find max absolute value per token (row) - exactly like CUDA kernel
    max_vals = torch.abs(input).max(dim=1)[0]  # [num_tokens]

    # Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX
    scales = max_vals / FP8_E4M3_MAX  # [num_tokens]

    # No special zero handling - directly compute 1.0 / scale like CUDA kernel
    scale_inv = 1.0 / scales  # [num_tokens]

    # Quantize: input * scale_inv, then clamp to FP8 range
    quantized_float = input * scale_inv.unsqueeze(1)  # Broadcast scale_inv
    quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)

    # Convert to FP8 - use more explicit conversion
    quantized_fp8 = quantized_float.to(fp8_type_)

    return quantized_fp8, scales

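# Illustrative example of the scheme above (values are hypothetical, not from a
# real run): on CUDA, FP8_E4M3_MAX is 448.0, so a row whose max |x| is 3.5 gets
# scale = 3.5 / 448.0 = 0.0078125; each element is stored as fp8(x / scale) and
# can be dequantized later as x_fp8 * scale.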

def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if not VLLM_AVAILABLE:
        # Fallback to SGLang implementation
        return sglang_per_token_quant_fp8(input)
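    # vLLM's dynamic per-token path returns (fp8 output, per-token scales).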
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)


def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    sgl_per_token_quant_fp8(input, output, scale)

    return output, scale


def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    """Compare Torch reference, VLLM, and SGLang implementations."""
    device = torch.device("cuda")
    x = torch.rand(
        (batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
    )

    # Get all three implementations
    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

    if not VLLM_AVAILABLE:
        print("⚠️ vLLM not available, skipping vLLM comparison")
        # Only compare Torch vs SGLang
        torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
        torch_sglang_out_diff = (
            torch.abs(torch_out.float() - sglang_out.float()).mean().item()
        )
        print(f"Scale difference (Torch vs SGLang): {torch_sglang_scale_diff:.8f}")
        print(f"Output difference (Torch vs SGLang): {torch_sglang_out_diff:.8f}")
        return

    # Compare scales
    torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
    torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
    vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()

    print(f"Scale differences:")
    print(f"  Torch vs VLLM:   {torch_vllm_scale_diff:.8f}")
    print(f"  Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
    print(f"  VLLM vs SGLang:  {vllm_sglang_scale_diff:.8f}")

    # Compare outputs
    torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
    torch_sglang_out_diff = (
        torch.abs(torch_out.float() - sglang_out.float()).mean().item()
    )
    vllm_sglang_out_diff = (
        torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
    )

    print(f"Output differences:")
    print(f"  Torch vs VLLM:   {torch_vllm_out_diff:.8f}")
    print(f"  Torch vs SGLang: {torch_sglang_out_diff:.8f}")
    print(f"  VLLM vs SGLang:  {vllm_sglang_out_diff:.8f}")

    # Check tolerances
    rtol, atol = 1e-3, 1e-5

    torch_vllm_match = torch.allclose(
        torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
    torch_sglang_match = torch.allclose(
        torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)

    if hidden_dim == 1368:
        # vLLM and SGLang diverge when hidden_dim is not divisible by 16,
        # and SGLang tracks the Torch reference more closely, so relax the
        # tolerance for the vLLM-vs-SGLang comparison below.
        rtol = 1e-2

    vllm_sglang_match = torch.allclose(
        vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)

    print(f"Matches (rtol={rtol}, atol={atol}):")
    print(f"  Torch vs VLLM:   {'✅' if torch_vllm_match else '❌'}")
    print(f"  Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
    print(f"  VLLM vs SGLang:  {'✅' if vllm_sglang_match else '❌'}")


# CI environment uses simplified parameters
if IS_CI:
    batch_size_range = [16]  # Single batch size for CI
    seq_len_range = [64]  # Single sequence length for CI
    hidden_dim_range = [2048]  # Single hidden dimension for CI
else:
    batch_size_range = [16, 32, 64, 128]
    seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
    hidden_dim_range = [1368, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_dim"],
        x_vals=configs,
        line_arg="provider",
        line_vals=(
            ["torch", "vllm", "sglang"] if VLLM_AVAILABLE else ["torch", "sglang"]
        ),
        line_names=(
            ["Torch Reference", "VLLM", "SGL Kernel"]
            if VLLM_AVAILABLE
            else ["Torch Reference", "SGL Kernel"]
        ),
        styles=(
            [("red", "-"), ("blue", "-"), ("green", "-")]
            if VLLM_AVAILABLE
            else [("red", "-"), ("green", "-")]
        ),
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
    """Benchmark one (batch_size, seq_len, hidden_dim) configuration for one provider."""
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]  # report median, 20th and 80th percentile timings

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        fn = lambda: vllm_per_token_quant_fp8(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

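    # do_bench_cudagraph reports milliseconds; convert to microseconds to match the "us" ylabel.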
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    # Test various hidden dimensions for correctness - simplified for CI
    if IS_CI:
        test_dims = [2048]  # Single dimension for CI
        batch_size, seq_len = 4, 64  # Smaller values for CI
    else:
        test_dims = [1368, 2048, 4096]
        batch_size, seq_len = 4, 4096

    for dim in test_dims:
        calculate_diff(batch_size=batch_size, seq_len=seq_len, hidden_dim=dim)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_quantization.run(print_data=True)