# bench_per_token_quant_fp8.py
import itertools
from typing import Tuple

import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_token_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn

# Get correct FP8 E4M3 maximum value
if _is_hip:
    FP8_E4M3_MAX = 224.0  # ROCm kernels use 224.0
else:
    # For CUDA, get the actual max value from the type
    FP8_E4M3_MAX = float(torch.finfo(fp8_type_).max)
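# For reference: torch.finfo(torch.float8_e4m3fn).max is 448.0 and
# torch.finfo(torch.float8_e4m3fnuz).max is 240.0, so the ROCm branch above
# intentionally uses a smaller hard-coded 224.0 rather than the dtype's finfo max.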


def torch_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch reference implementation for per-token FP8 quantization."""

    # Find max absolute value per token (row) - exactly like CUDA kernel
    max_vals = torch.abs(input).max(dim=1)[0]  # [num_tokens]

    # Calculate scale per token - exactly like CUDA kernel: scale = max_value / FP8_E4M3_MAX
    scales = max_vals / FP8_E4M3_MAX  # [num_tokens]

    # No special zero handling: compute 1.0 / scale directly, like the CUDA kernel
    # (an all-zero row therefore yields an inf scale_inv)
    scale_inv = 1.0 / scales  # [num_tokens]

    # Quantize: input * scale_inv, then clamp to FP8 range
    quantized_float = input * scale_inv.unsqueeze(1)  # Broadcast scale_inv
    quantized_float = torch.clamp(quantized_float, -FP8_E4M3_MAX, FP8_E4M3_MAX)

    # Cast to the target FP8 dtype (values are already clamped to the usable range)
    quantized_fp8 = quantized_float.to(fp8_type_)

    return quantized_fp8, scales

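# --- Illustrative only: not part of the original benchmark ---
# A minimal sanity-check sketch using the Torch reference above. The helper
# name `_dequant_round_trip_error` is hypothetical; it dequantizes the FP8
# output with its per-token scale and reports the mean drift from the input.
def _dequant_round_trip_error(input: torch.Tensor) -> float:
    quantized, scales = torch_per_token_quant_fp8(input)
    dequantized = quantized.float() * scales.unsqueeze(1)  # broadcast per-token scale
    return torch.abs(dequantized - input.float()).mean().item()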

def vllm_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)


def sglang_per_token_quant_fp8(
    input: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
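    # sgl_per_token_quant_fp8 writes its results into pre-allocated buffers:
    # a float32 per-token scale vector and an FP8 output tensor.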
    scale = torch.zeros(input.size(0), device=input.device, dtype=torch.float32)
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    sgl_per_token_quant_fp8(input, output, scale)
    return output, scale


def calculate_diff(batch_size: int, seq_len: int, hidden_dim: int):
    """Compare Torch reference, VLLM, and SGLang implementations."""
    device = torch.device("cuda")

    x = torch.rand(
        (batch_size * seq_len, hidden_dim), dtype=torch.float16, device=device
    )

    # Get all three implementations
    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
    sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)

    print(f"\n=== Comparison for hidden_dim={hidden_dim} ===")

    # Compare scales
    torch_vllm_scale_diff = torch.abs(torch_scale - vllm_scale).mean().item()
    torch_sglang_scale_diff = torch.abs(torch_scale - sglang_scale).mean().item()
    vllm_sglang_scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()

    print(f"Scale differences:")
    print(f"  Torch vs VLLM:   {torch_vllm_scale_diff:.8f}")
    print(f"  Torch vs SGLang: {torch_sglang_scale_diff:.8f}")
    print(f"  VLLM vs SGLang:  {vllm_sglang_scale_diff:.8f}")

    # Compare outputs
    torch_vllm_out_diff = torch.abs(torch_out.float() - vllm_out.float()).mean().item()
    torch_sglang_out_diff = (
        torch.abs(torch_out.float() - sglang_out.float()).mean().item()
    )
    vllm_sglang_out_diff = (
        torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
    )

    print(f"Output differences:")
    print(f"  Torch vs VLLM:   {torch_vllm_out_diff:.8f}")
    print(f"  Torch vs SGLang: {torch_sglang_out_diff:.8f}")
    print(f"  VLLM vs SGLang:  {vllm_sglang_out_diff:.8f}")

    # Check tolerances
    rtol, atol = 1e-3, 1e-5

    torch_vllm_match = torch.allclose(
        torch_out.float(), vllm_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, vllm_scale, rtol=rtol, atol=atol)
    torch_sglang_match = torch.allclose(
        torch_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(torch_scale, sglang_scale, rtol=rtol, atol=atol)

    if hidden_dim == 1368:
        # vLLM and SGLang outputs diverge when hidden_dim is not divisible by 16;
        # SGLang tracks the Torch reference more closely, so the rtol is relaxed
        # for the VLLM-vs-SGLang check below.
        rtol = 1e-2

    vllm_sglang_match = torch.allclose(
        vllm_out.float(), sglang_out.float(), rtol=rtol, atol=atol
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=rtol, atol=atol)

    print(f"Matches (rtol={rtol}, atol={atol}):")
    print(f"  Torch vs VLLM:   {'✅' if torch_vllm_match else '❌'}")
    print(f"  Torch vs SGLang: {'✅' if torch_sglang_match else '❌'}")
    print(f"  VLLM vs SGLang:  {'✅' if vllm_sglang_match else '❌'}")


batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048, 4096]
hidden_dim_range = [1368, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range, hidden_dim_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_dim"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "vllm", "sglang"],
        line_names=["Torch Reference", "VLLM", "SGL Kernel"],
        styles=[("red", "-"), ("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "vllm":
        fn = lambda: vllm_per_token_quant_fp8(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # do_bench_cudagraph reports milliseconds; convert to microseconds to match ylabel
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    # Test various hidden dimensions for correctness
    test_dims = [1368, 2048, 4096]

    for dim in test_dims:
        calculate_diff(batch_size=4, seq_len=4096, hidden_dim=dim)

    print("\n" + "=" * 60)
    print("Starting performance benchmark...")
    benchmark_quantization.run(print_data=True)
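    # Optional (not in the original script): depending on your Triton version,
    # run() also accepts save_path to write the results table and plot to disk, e.g.
    #   benchmark_quantization.run(print_data=True, save_path="./fp8_quant_bench/")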