# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import glob
import itertools
import os
import subprocess

import jinja2

# Preamble placed at the top of every generated kernel_*.cu file. It opens
# the MARLIN_NAMESPACE_NAME namespace; the writer appends the closing brace.
FILE_HEAD = "\n".join((
    "// auto generated by generate.py",
    "// clang-format off",
    "",
    '#include "kernel.h"',
    '#include "marlin_template.h"',
    "",
    "namespace MARLIN_NAMESPACE_NAME {",
))

# Jinja2 source for one explicit instantiation of the Marlin kernel template.
# The template arguments are listed one per line, in the order they are
# emitted between the angle brackets.
TEMPLATE = "template __global__ void Marlin<" + ", ".join((
    "{{scalar_t}}",
    "{{w_type_id}}",
    "{{threads}}",
    "{{thread_m_blocks}}",
    "{{thread_n_blocks}}",
    "{{thread_k_blocks}}",
    "{{'true' if m_block_size_8 else 'false'}}",
    "{{stages}}",
    "{{group_blocks}}",
    "{{'true' if is_zp_float else 'false'}}",
)) + ">( MARLIN_KERNEL_PARAMS );"

# Weight scalar types for which kernels are instantiated.
# The int8 with zero point case (vllm::kU8) is also supported, but it is
# deliberately left out here to reduce wheel size.
SCALAR_TYPES = [
    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
    "vllm::kFE2M1f"
]
# Each entry is (k, n, threads): generate_new_kernels() emits
# thread_k_blocks = k // 16, thread_n_blocks = n // 16 and uses the third
# element as the `threads` template argument.
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
                  (128, 64, 128)]

# 0.5 is a marker value: generate_new_kernels() renders it as
# thread_m_blocks = 1 with m_block_size_8 = true.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]
# Activation dtypes; generate_new_kernels() maps "fp16" to the C++ type
# `half` and anything else (i.e. "bf16") to `nv_bfloat16`.
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    """Delete every ``kernel_*.cu`` file located next to this script.

    The search pattern is built with ``os.path.join`` so that an empty
    directory component (``__file__`` given as a bare filename) resolves to
    the current directory rather than the filesystem root, and the directory
    part is escaped so glob metacharacters in the path are taken literally.

    Raises:
        OSError: if a matching file exists but cannot be removed.
    """
    script_dir = os.path.dirname(__file__)
    pattern = os.path.join(glob.escape(script_dir), "kernel_*.cu")
    for filename in glob.glob(pattern):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Same tolerance as `rm -f`: a file that vanished between the
            # glob and the removal is not an error.
            pass


def generate_new_kernels():
    """Write one ``kernel_<dtype>_<scalar type>.cu`` file per type pair.

    For every (scalar type, dtype) combination this collects the explicit
    ``Marlin<...>`` template instantiations for all supported
    (group_blocks, m_blocks, thread config) combinations, wraps them in
    FILE_HEAD plus a closing brace, and writes the result into the directory
    containing this script.
    """
    # The template source never changes, so compile it once instead of once
    # per rendered instantiation.
    kernel_template = jinja2.Template(TEMPLATE)
    output_dir = os.path.dirname(__file__)

    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            # act order case only support gptq-int4 and gptq-int8
            if group_blocks == 0 and scalar_type not in [
                    "vllm::kU4B8", "vllm::kU8B128"
            ]:
                continue
            if thread_configs[2] == 256:
                # for small batch (m_blocks <= 1), we only need (128, 128, 256)
                # for large batch (m_blocks > 1), we only need (64, 256, 256)
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            # we only support channelwise quantization and group_size == 128
            # for fp8
            if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
                continue
            # nvfp4 only supports group_size == 16
            if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
                continue
            # other quantization methods don't support group_size = 16
            if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
                continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            is_zp_float_list = [False]
            if dtype == "fp16" and scalar_type == "vllm::kU4" and \
                    group_blocks == 4:
                # HQQ (is_zp_float = true) only supports
                # 4bit quantization and fp16
                is_zp_float_list.append(True)

            for is_zp_float in is_zp_float_list:
                template_str = kernel_template.render(
                    scalar_t=c_dtype,
                    w_type_id=scalar_type + ".id()",
                    threads=threads,
                    # m_blocks == 0.5 is rendered as thread_m_blocks = 1
                    # together with m_block_size_8 = true.
                    thread_m_blocks=max(m_blocks, 1),
                    thread_n_blocks=n_blocks,
                    thread_k_blocks=k_blocks,
                    m_block_size_8=m_blocks == 0.5,
                    # Emitted verbatim as the identifier `pipe_stages`.
                    stages="pipe_stages",
                    group_blocks=group_blocks,
                    is_zp_float=is_zp_float,
                )

                all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        # The trailing "}" closes the namespace opened in FILE_HEAD.
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        # scalar_type[6:] drops the leading "vllm::" from the type name.
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(output_dir, filename), "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    # Script entry point: delete the previously generated kernel_*.cu files,
    # then regenerate the full set.
    remove_old_kernels()
    generate_new_kernels()