# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import itertools
import os
import subprocess

import jinja2

# Preamble written at the top of every generated .cu file. It opens the
# MARLIN_NAMESPACE_NAME namespace; the matching "}" is appended when the file
# content is assembled in generate_new_kernels().
FILE_HEAD = """
// auto generated by generate_kernels.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

# Jinja2 source for a single explicit instantiation of the `Marlin` kernel
# template. The placeholders are filled in by generate_new_kernels(); the two
# boolean placeholders are rendered as the C++ literals `true` / `false`.
TEMPLATE = (
    "template __global__ void Marlin<"
    "{{scalar_t}}, "
    "{{w_type_id}}, "
    "{{s_type_id}}, "
    "{{threads}}, "
    "{{thread_m_blocks}}, "
    "{{thread_n_blocks}}, "
    "{{thread_k_blocks}}, "
    "{{'true' if m_block_size_8 else 'false'}}, "
    "{{stages}}, "
    "{{group_blocks}}, "
    "{{'true' if is_zp_float else 'false'}}>"
    "( MARLIN_KERNEL_PARAMS );"
)

# Weight types to instantiate kernels for.
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = [
    "vllm::kU4",
    "vllm::kU4B8",
    "vllm::kU8B128",
    "vllm::kFE4M3fn",
    "vllm::kFE2M1f",
]

# Each entry is read by generate_new_kernels() as
# (k size, n size, threads): the first two values are divided by 16 to get
# thread_k_blocks / thread_n_blocks, the third is passed through as `threads`.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]

# 0.5 is a marker value: it is rendered as thread_m_blocks == 1 with
# m_block_size_8 == true. All other values are passed through unchanged.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]

# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]

# Activation dtypes; "fp16" maps to the C++ type `half`, "bf16" to
# `nv_bfloat16` when the kernels are rendered.
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    """Delete every ``kernel_*.cu`` file located next to this script.

    Files that have already disappeared by the time they are removed are
    ignored. Any other ``OSError`` (e.g. a permission problem) propagates so
    that stale kernel files are not silently left behind.
    """
    # os.path.join keeps the pattern relative when dirname(__file__) is empty;
    # plain concatenation would turn it into "/kernel_*.cu" (filesystem root).
    pattern = os.path.join(os.path.dirname(__file__), "kernel_*.cu")
    for filename in glob.glob(pattern):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Already gone; nothing left to clean up for this entry.
            pass


def generate_new_kernels():
    """Write one ``kernel_<dtype>_<weight type>.cu`` file per type pair.

    For every combination of ``SCALAR_TYPES`` and ``DTYPES`` this walks the
    product of ``GROUP_BLOCKS``, ``THREAD_M_BLOCKS`` and ``THREAD_CONFIGS``,
    skips the combinations rejected by the filters below, renders
    ``TEMPLATE`` for each remaining one and writes the instantiations,
    preceded by ``FILE_HEAD`` and followed by a closing brace, into the
    directory that contains this script.

    Raises:
        ValueError: if a dtype other than "fp16" / "bf16" reaches the
            scale-type selection.
    """
    # The template text is constant, so compile it once instead of once per
    # rendered instantiation.
    kernel_template = jinja2.Template(TEMPLATE)
    output_dir = os.path.dirname(__file__)

    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
        ):
            # act order case only support gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                "vllm::kU4B8",
                "vllm::kU8B128",
            ]:
                continue
            if thread_configs[2] == 256:
                # for small batch (m_blocks <= 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # we only support channelwise quantization and group_size == 128
            # for fp8
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue
            # nvfp4 only supports group_size == 16
            # mxfp4 only supports group_size == 32
            if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
                continue
            # other quantization methods don't support group_size = 16
            if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            is_zp_float_list = [False]
            if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
                # HQQ (is_zp_float = true) only supports
                # 4bit quantization and fp16
                is_zp_float_list.append(True)

            # Pick the scale type that accompanies this weight type / dtype.
            if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
                s_type = "vllm::kFE4M3fn"
            elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
                s_type = "vllm::kFE8M0fnu"
                if dtype == "fp16":
                    # we cannot safely dequantize e8m0 to fp16, so skip this
                    continue
            elif dtype == "fp16":
                s_type = "vllm::kFloat16"
            elif dtype == "bf16":
                s_type = "vllm::kBFloat16"
            else:
                # Fail loudly instead of reusing a stale (or unbound) s_type.
                raise ValueError(f"unsupported dtype: {dtype!r}")

            for is_zp_float in is_zp_float_list:
                template_str = kernel_template.render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    s_type_id=s_type + ".id()",
                    threads=threads,
                    # m_blocks == 0.5 is a marker: it is emitted as one m block
                    # with the m_block_size_8 flag set.
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=n_blocks,
                    thread_k_blocks=k_blocks,
                    m_block_size_8=m_blocks == 0.5,
                    # Emitted verbatim into the generated source.
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=is_zp_float,
                )

                all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        # The trailing "}" closes the namespace opened by FILE_HEAD.
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # scalar_type[6:] drops the leading "vllm::" (e.g. "kU4B8" -> "ku4b8").
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(output_dir, filename), "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()