# bench_int8_gemm.py
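"""Benchmark sgl-kernel's int8_scaled_mm, optionally against vLLM's
cutlass_scaled_mm, across common LLM weight shapes at several batch sizes."""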
import argparse
import copy
import itertools
import os

import torch
import triton
from sgl_kernel import int8_scaled_mm

# Optional vLLM import
try:
    from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm

    VLLM_AVAILABLE = True
except ImportError:
    vllm_scaled_mm = None
    VLLM_AVAILABLE = False

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
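    """Clamp to the int8 range [-128, 127] and round before casting to torch.int8."""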
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


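# Per-model GEMM weight shapes as ([K, N], tp_split_dim), where tp_split_dim
# marks the dimension (0 = K, 1 = N) that is sharded under tensor parallelism.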
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


# CI environment uses simplified parameters
if IS_CI:
    batch_sizes = [1]  # Single batch size for CI
else:
    batch_sizes = [1, 16, 32, 64, 128, 256, 512, 1024, 2048]

# Filter providers based on vLLM availability
if VLLM_AVAILABLE:
    line_vals = ["vllm", "sgl-kernel"]
    line_names = ["vllm int8 gemm", "sgl-kernel int8 gemm"]
    styles = [("blue", "-"), ("orange", "-")]
else:
    line_vals = ["sgl-kernel"]
    line_names = ["sgl-kernel int8 gemm"]
    styles = [("orange", "-")]


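# One perf_report sweep per (N, K) weight shape, varying the batch size M.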
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=batch_sizes,
        x_log=False,
        line_arg="provider",
        line_vals=line_vals,
        line_names=line_names,
        styles=styles,
        ylabel="GB/s",
        plot_name="int8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
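    # a: (M, K) int8 activations; b: (K, N) int8 weights, column-major via .t();
    # per-row (M,) and per-column (N,) float32 scales, plus an (N,) float16 bias.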
    a = to_int8(torch.randn((M, K), device="cuda") * 5)
    b = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch.float16)

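    # Median (0.5), low (0.2), and high (0.8) latency quantiles from Triton's
    # CUDA-graph benchmark helper.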
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
    elif provider == "vllm":
        if not VLLM_AVAILABLE:
            return (0, 0, 0)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
            quantiles=quantiles,
        )
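    # Effective throughput in GB/s: GEMM op count weighted by the int8 element
    # size, plus scale/bias traffic, divided by the measured time.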
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
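    """Expand each (model, tp_size) pair into [K, N, model_name] entries,
    dividing the tensor-parallel split dimension by tp_size."""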
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # Skip in CI environment due to architecture compatibility issues
    if IS_CI:
        print(
            "Skipping INT8 GEMM benchmark in CI environment due to architecture compatibility issues"
        )
        print("INT8 operations may not be supported on all GPU architectures")
    else:
        KN_model_names = prepare_shapes(args)
        for K, N, model_name in KN_model_names:
            print(f"{model_name} N={N} K={K}: ")
            benchmark.run(print_data=True, N=N, K=K)

        print("Benchmark finished!")