import itertools
import os

import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize

# Optional vLLM import
try:
    from vllm import _custom_ops as ops

    VLLM_AVAILABLE = True
except ImportError:
    ops = None
    VLLM_AVAILABLE = False

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)


def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    if not VLLM_AVAILABLE:
        # Fallback to SGLang implementation
        return sglang_awq_dequantize(qweight, scales, qzeros)
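    # The three trailing ints are split_k_iters, thx, and thy; zeros match how
    # vLLM's AWQ path invokes a full dequantize.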
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)
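

# For reference, a pure-PyTorch sketch of what both kernels compute. This is a
# simplified illustration under assumed layout: it unpacks eight 4-bit values
# per int32 in plain nibble order and ignores AWQ's interleaved pack order, so
# it is not bit-exact against the real kernels; it only documents the math
# (weight = (int4 - zero) * scale, with one scales/zeros row per group).
def reference_awq_dequantize(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    shifts = torch.arange(0, 32, 4, device=qweight.device)
    # (rows, cols) int32 -> (rows, cols * 8) unpacked 4-bit integers.
    iweights = ((qweight.unsqueeze(-1) >> shifts) & 0xF).reshape(qweight.shape[0], -1)
    izeros = ((qzeros.unsqueeze(-1) >> shifts) & 0xF).reshape(qzeros.shape[0], -1)
    # Broadcast each scales/zeros row over its group of group_size weight rows.
    izeros = izeros.repeat_interleave(group_size, dim=0)
    expanded_scales = scales.repeat_interleave(group_size, dim=0)
    return (iweights - izeros).to(scales.dtype) * expanded_scales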


def calculate_diff(qweight_row: int, qweight_col: int):
    """Calculate difference between VLLM and SGLang implementations."""
    device = torch.device("cuda")
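    # AWQ packs eight 4-bit weights into each int32, so the dequantized output
    # (and the scales) have qweight_col * 8 columns; group_size == qweight_row
    # yields a single quantization group, i.e. one row of scales and qzeros.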
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    if not VLLM_AVAILABLE:
        print("⚠️ vLLM not available, skipping comparison")
        return

    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ):
        print(f"✅ All implementations match (mean abs diff: {output_diff:.6f})")
    else:
        print(f"❌ Implementations differ (mean abs diff: {output_diff:.6f})")


# CI environment uses simplified parameters
if IS_CI:
    qweight_row_range = [128]  # Single row size for CI
    qweight_cols_range = [16]  # Single column size for CI
else:
    qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
    qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]

configs = list(itertools.product(qweight_row_range, qweight_cols_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["qweight_row", "qweight_col"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"] if VLLM_AVAILABLE else ["sglang"],
        line_names=["VLLM", "SGL Kernel"] if VLLM_AVAILABLE else ["SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")] if VLLM_AVAILABLE else [("green", "-")],
        ylabel="us",
        plot_name="awq-dequantize-performance",
        args={},
    )
)
def benchmark(qweight_row, qweight_col, provider):
    dtype = torch.float16
    device = torch.device("cuda")
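    # Same synthetic AWQ layout as in calculate_diff above.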
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=dtype, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
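        # clone() hands each replayed call fresh inputs so cached or in-place
        # effects cannot skew the CUDA-graph timing.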
        fn = lambda: vllm_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )
    elif provider == "sglang":
        fn = lambda: sglang_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )

    # Median latency plus 20th/80th percentiles, measured under CUDA graphs.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # do_bench reports milliseconds; scale to microseconds to match the ylabel.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    # Simplify for CI environment
    if IS_CI:
        qweight_row, qweight_col = 128, 16  # Smaller values for CI
    else:
        qweight_row, qweight_col = 3584, 448

    calculate_diff(qweight_row=qweight_row, qweight_col=qweight_col)
    benchmark.run(print_data=True)