# generate_kernels.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
import glob
import itertools
import os
import subprocess
7
import sys
8
9
10

import jinja2

11
12
ARCHS = []
SUPPORT_FP8 = False
SUPPORT_SM75 = False
SUPPORT_SM80 = False


def _parse_cuda_arch(arch_str):
    """Convert one CUDA arch entry into its integer SM number.

    Examples: "8.9" -> 89, "9.0a" -> 90, "12.0+PTX" -> 120. Only the first
    character after the dot is kept, so suffixes such as "a" or "+PTX" are
    dropped. Entries written without a dot (e.g. "89", "120a") are accepted
    by taking their leading digits.

    Raises:
        ValueError: if no SM number can be extracted from ``arch_str``.
    """
    arch_str = arch_str.strip()
    major, dot, minor = arch_str.partition(".")
    if dot:
        digits = major + minor[:1]
    else:
        digits = "".join(itertools.takewhile(str.isdigit, arch_str))
    if not digits:
        raise ValueError(f"cannot parse CUDA arch {arch_str!r}")
    return int(digits)


# sys.argv[1] is a comma-separated list of CUDA arch strings, e.g. "8.0,8.9,9.0a".
for arch in sys.argv[1].split(","):
    if not arch.strip():
        # Tolerate empty entries, e.g. produced by a trailing comma.
        continue
    arch = _parse_cuda_arch(arch)
    # only SM89 and SM120 fully support
    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
    # SM90 and SM100 can use this PTX, but it's simulated
    # with FP16 MMA, so it cannot achieve any acceleration.
    if arch in [89, 120]:
        SUPPORT_FP8 = True
    if arch >= 80:
        SUPPORT_SM80 = True
    if arch == 75:
        SUPPORT_SM75 = True

# Banner written at the top of every generated file, including
# kernel_selector.h. "// clang-format off" disables clang-format for the
# generated sources.
FILE_HEAD_COMMENT = """
// auto generated by generate_kernels.py
// clang-format off
""".lstrip()

# Preamble of each generated *_kernel_*.cu file. The namespace opened here is
# closed by the generator ("}") after the explicit instantiations.
FILE_HEAD = (
    FILE_HEAD_COMMENT
    + """
#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
"""
)
43

44
45
# jinja2 template for one explicit Marlin kernel instantiation. It is rendered
# once per kernel config into the generated .cu files, using the config dict
# plus the a/b/c/s type ids. The argument order must match the template
# parameter list of ``Marlin`` (declared outside this file, presumably in
# marlin_template.h -- confirm when changing).
TEMPLATE = (
    "template __global__ void Marlin<"
    "{{a_type_id}}, "
    "{{b_type_id}}, "
    "{{c_type_id}}, "
    "{{s_type_id}}, "
    "{{threads}}, "
    "{{thread_m_blocks}}, "
    "{{thread_n_blocks}}, "
    "{{thread_k_blocks}}, "
    "{{m_block_size_8}}, "
    "{{stages}}, "
    "{{group_blocks}}, "
    "{{is_zp_float}}>"
    "( MARLIN_KERNEL_PARAMS );"
)
60

61
# Candidate (thread_k, thread_n, threads) tile shapes. The generator divides
# thread_k and thread_n by 16 to obtain thread_k_blocks / thread_n_blocks.
THREAD_CONFIGS = [
    (128, 128, 256),
    (64, 256, 256),
    (64, 128, 128),
    (128, 64, 128),
]

# Candidate thread_m_blocks values. 0.5 is a sentinel: the generator emits it
# as thread_m_blocks=1 together with m_block_size_8=true.
THREAD_M_BLOCKS = [0.5, *range(1, 5)]
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

# Quantization schemes to generate kernels for. Each entry is expanded over
# its a_type and c_type lists (both default to ["kFloat16", "kBFloat16"]),
# thread_configs, thread_m_blocks and group_blocks. When omitted, s_type
# defaults to the c_type and is_zp_float defaults to False.
QUANT_CONFIGS = [
    # AWQ-INT4
    dict(
        b_type="kU4",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=THREAD_M_BLOCKS,
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4
    dict(
        b_type="kU4B8",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=THREAD_M_BLOCKS,
        group_blocks=[-1, 0, 2, 4, 8],
    ),
    # GPTQ-INT8
    dict(
        b_type="kU8B128",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=THREAD_M_BLOCKS,
        group_blocks=[-1, 0, 2, 4, 8],
    ),
    # FP8
    dict(
        b_type="kFE4M3fn",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=THREAD_M_BLOCKS,
        group_blocks=[-1, 8],
    ),
    # NVFP4
    dict(
        b_type="kFE2M1f",
        s_type="kFE4M3fn",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=THREAD_M_BLOCKS,
        group_blocks=[1],
    ),
    # MXFP4
    dict(
        a_type=["kBFloat16"],
        b_type="kFE2M1f",
        s_type="kFE8M0fnu",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=THREAD_M_BLOCKS,
        group_blocks=[2],
    ),
    # AWQ-INT4 with INT8 activation
    dict(
        a_type=["kS8"],
        b_type="kU4",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=[1, 2, 3, 4],
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4 with INT8 activation
    dict(
        a_type=["kS8"],
        b_type="kU4B8",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=[1, 2, 3, 4],
        group_blocks=[-1, 2, 4, 8],
    ),
    # GPTQ-INT4 with FP8 activation
    dict(
        a_type=["kFE4M3fn"],
        b_type="kU4B8",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=[1, 2, 3, 4],
        group_blocks=[-1, 2, 4, 8],
    ),
    # AWQ-INT4 with FP8 activation
    dict(
        a_type=["kFE4M3fn"],
        b_type="kU4",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=[1, 2, 3, 4],
        group_blocks=[-1, 2, 4, 8],
    ),
    # MXFP4 with FP8 activation
    dict(
        a_type=["kFE4M3fn"],
        b_type="kFE2M1f",
        c_type=["kBFloat16"],
        s_type="kFE8M0fnu",
        thread_configs=THREAD_CONFIGS,
        thread_m_blocks=[1, 2, 3, 4],
        group_blocks=[2],
    ),
]
154
155
156


def remove_old_kernels():
    """Delete previously generated kernel sources next to this script.

    Removes every ``*kernel_*.cu`` file and ``kernel_selector.h`` from the
    directory containing this file. Files that do not exist are ignored.
    """
    # Join instead of concatenating: when __file__ has no directory part,
    # dirname() is "" and "" + "/x" would point at the filesystem root, while
    # generate_new_kernels() writes relative to the current directory.
    kernel_dir = os.path.dirname(__file__)
    # Escape the directory so glob metacharacters in the path are literal.
    stale_files = glob.glob(os.path.join(glob.escape(kernel_dir), "*kernel_*.cu"))
    stale_files.append(os.path.join(kernel_dir, "kernel_selector.h"))
    for path in stale_files:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

163
164

def _collect_kernel_configs():
    """Expand QUANT_CONFIGS into kernel configs keyed by (a_type, b_type, c_type).

    Returns ``(result_dict, sm_75_result_dict)``. The first holds SM80+
    configs (``stages`` = 4); the second holds SM75 configs (``stages`` = 2).
    A key is registered for every enabled type combination even if its list
    stays empty because the matching SUPPORT_* flag is off.
    """
    result_dict = {}
    sm_75_result_dict = {}

    for quant_config in QUANT_CONFIGS:
        c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"])
        a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"])
        b_type = quant_config["b_type"]
        is_zp_float = quant_config.get("is_zp_float", False)
        all_group_blocks = quant_config["group_blocks"]
        all_m_blocks = quant_config["thread_m_blocks"]
        all_thread_configs = quant_config["thread_configs"]

        for a_type, c_type in itertools.product(a_types, c_types):
            # FP8-activation kernels are only generated when SUPPORT_FP8 is set.
            if not SUPPORT_FP8 and a_type == "kFE4M3fn":
                continue
            # Mixed fp16/bf16 activation/output pairs are not generated.
            if "16" in a_type and "16" in c_type and a_type != c_type:
                continue

            key = (a_type, b_type, c_type)
            s_type = quant_config.get("s_type", c_type)
            if key not in result_dict:
                result_dict[key] = []
                # SM75 variants only exist for fp16 output with fp16/int8 input.
                if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16":
                    sm_75_result_dict[key] = []

            for group_blocks, m_blocks, thread_configs in itertools.product(
                all_group_blocks, all_m_blocks, all_thread_configs
            ):
                thread_k, thread_n, threads = thread_configs

                if threads == 256:
                    # for small batch (m_blocks == 1),
                    #     we only need (128, 128, 256)
                    # for large batch (m_blocks > 1),
                    #     we only need (64, 256, 256)
                    if m_blocks <= 1 and (thread_k, thread_n) != (128, 128):
                        continue
                    if m_blocks > 1 and (thread_k, thread_n) != (64, 256):
                        continue

                config = {
                    "threads": threads,
                    "s_type": s_type,
                    # m_blocks == 0.5 is encoded as 1 block + m_block_size_8.
                    "thread_m_blocks": max(m_blocks, 1),
                    "thread_k_blocks": thread_k // 16,
                    "thread_n_blocks": thread_n // 16,
                    "m_block_size_8": "true" if m_blocks == 0.5 else "false",
                    "stages": 4,
                    "group_blocks": group_blocks,
                    "is_zp_float": "true" if is_zp_float else "false",
                }

                if SUPPORT_SM80:
                    result_dict[key].append(config)
                if SUPPORT_SM75 and key in sm_75_result_dict:
                    sm_75_result_dict[key].append({**config, "stages": 2})

    return result_dict, sm_75_result_dict


def _selector_condition(a_type, b_type, c_type, config):
    """Return the C++ boolean expression that selects ``config`` at runtime."""
    return " && ".join(
        [
            f"a_type == vllm::{a_type}",
            f"b_type == vllm::{b_type}",
            f"c_type == vllm::{c_type}",
            f"s_type == vllm::{config['s_type']}",
            f"threads == {config['threads']}",
            f"thread_m_blocks == {config['thread_m_blocks']}",
            f"thread_n_blocks == {config['thread_n_blocks']}",
            f"thread_k_blocks == {config['thread_k_blocks']}",
            f"m_block_size_8 == {config['m_block_size_8']}",
            f"stages == {config['stages']}",
            f"group_blocks == {config['group_blocks']}",
            f"is_zp_float == {config['is_zp_float']}",
        ]
    )


def _kernel_filename(a_type, b_type, c_type, is_sm75):
    """Return the lower-cased .cu filename for one (a, b, c) type combination."""
    if a_type == "kFE4M3fn":
        prefix = "sm89"
    elif is_sm75:
        prefix = "sm75"
    else:
        prefix = "sm80"
    return f"{prefix}_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu".lower()


def generate_new_kernels():
    """Write the Marlin kernel instantiation files and kernel_selector.h.

    For every (a_type, b_type, c_type) combination with at least one config,
    writes one ``sm{75,80,89}_kernel_*.cu`` file containing an explicit
    ``Marlin<...>`` instantiation per config. Alongside, it builds an
    ``if`` / ``else if`` chain in ``kernel_selector.h`` that assigns the
    matching instantiation to ``kernel``. All files are written to the
    directory of this script.
    """
    result_dict, sm_75_result_dict = _collect_kernel_configs()

    # Compile both jinja2 templates once rather than once per config.
    instantiation_template = jinja2.Template(TEMPLATE)
    selector_template = jinja2.Template(
        "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, "
        "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, "
        "{{thread_n_blocks}}, {{thread_k_blocks}}, "
        "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, "
        "{{is_zp_float}}>;"
    )
    out_dir = os.path.dirname(__file__)

    selector_branches = []
    for configs_by_types, is_sm75 in (
        (result_dict, False),
        (sm_75_result_dict, True),
    ):
        for (a_type, b_type, c_type), config_list in configs_by_types.items():
            if not config_list:
                continue
            instantiations = []
            for config in config_list:
                type_ids = {
                    "a_type_id": f"vllm::{a_type}.id()",
                    "b_type_id": f"vllm::{b_type}.id()",
                    "c_type_id": f"vllm::{c_type}.id()",
                    "s_type_id": f"vllm::{config['s_type']}.id()",
                }
                instantiations.append(
                    instantiation_template.render(**type_ids, **config)
                )

                keyword = "else if" if selector_branches else "if"
                condition = _selector_condition(a_type, b_type, c_type, config)
                selector_branches.append(
                    f"{keyword} ({condition})\n  kernel = "
                    + selector_template.render(**type_ids, **config)
                    + "\n"
                )

            file_content = FILE_HEAD + "\n\n"
            file_content += "\n\n".join(instantiations) + "\n\n}\n"
            filename = _kernel_filename(a_type, b_type, c_type, is_sm75)
            with open(os.path.join(out_dir, filename), "w", encoding="utf-8") as f:
                f.write(file_content)

    kernel_selector_str = FILE_HEAD_COMMENT + "".join(selector_branches)
    if not SUPPORT_FP8 and selector_branches:
        # Give FP8-activation requests an explicit error in builds where
        # those kernels were not generated; end the file with a newline.
        kernel_selector_str += (
            "else if (a_type == vllm::kFE4M3fn)\n"
            "  TORCH_CHECK(false, "
            '"marlin kernel with fp8 activation is not built.");\n'
        )

    selector_path = os.path.join(out_dir, "kernel_selector.h")
    with open(selector_path, "w", encoding="utf-8") as f:
        f.write(kernel_selector_str)

304
305
306
307

if __name__ == "__main__":
    # Delete previously generated sources first: generation only writes files
    # for the currently enabled combinations, so stale ones would otherwise
    # remain. Then regenerate everything.
    for _step in (remove_old_kernels, generate_new_kernels):
        _step()