generate_cpu_attn_dispatch.py 6.35 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generate CPU attention dispatch switch cases and kernel instantiations.
"""

import os

# Every multiple of 32 from 32 up to and including 256. Each ISA listed in
# ISA_FOR_32 gets a dedicated kernel case for these head dimensions.
HEAD_DIMS_32 = list(range(32, 256 + 1, 32))

# Multiples of 16 that are not multiples of 32. Only the VEC16 kernel can
# serve these, so the generator routes every requested ISA to VEC16 for them.
HEAD_DIMS_16 = [80, 112]

# Numeric ISA identifiers stored in the low 8 bits of the encoded dispatch key.
# The generated C++ helper encodes `static_cast<int64_t>(isa)`, so these
# numbers must agree with the underlying values of `cpu_attention::ISA`.
ISA_TYPES = {"AMX": 0, "VEC": 1, "VEC16": 2, "NEON": 3, "VXE": 4}

# ISAs that provide a kernel for head dimensions divisible by 32.
ISA_FOR_32 = [
    "AMX",
    "NEON",
    "VEC",
    "VEC16",
    "VXE",
]

# ISAs that provide a kernel for head dimensions divisible by 16 but not 32.
ISA_FOR_16 = ["VEC16"]


def encode_params(head_dim: int, isa_type: str) -> int:
    """Encode head_dim and ISA type into a single int64_t.

    Layout: ``(head_dim << 8) | isa_value``. The low 8 bits carry the ISA
    value from ``ISA_TYPES``; the remaining bits carry ``head_dim``. This must
    match the generated C++ ``encode_cpu_attn_params`` helper, because the
    dispatch ``switch`` compares that helper's result against the case labels
    produced from this function.

    Args:
        head_dim: Attention head dimension.
        isa_type: Key of ``ISA_TYPES`` (e.g. ``"AMX"``, ``"VEC16"``).

    Returns:
        The encoded value, which fits in a signed 64-bit integer.

    Raises:
        KeyError: If ``isa_type`` is not a key of ``ISA_TYPES``.
        ValueError: If the ISA value needs more than 8 bits, or if
            ``head_dim`` is negative or too large for the shifted value to
            fit in a signed int64_t.
    """
    isa_val = ISA_TYPES[isa_type]
    # The ISA occupies only the low 8 bits; a larger value would bleed into
    # the head_dim bits and make distinct configurations collide.
    if not 0 <= isa_val <= 0xFF:
        raise ValueError(
            f"ISA value {isa_val} for {isa_type!r} does not fit in 8 bits"
        )
    # int64_t is signed, so after the 8-bit shift head_dim may be at most
    # 2**55 - 1 (not 2**56 - 1) for the result to stay representable.
    max_head_dim = (2**63 - 1) >> 8
    if not 0 <= head_dim <= max_head_dim:
        raise ValueError(
            f"head_dim={head_dim} is outside the encodable range "
            f"[0, {max_head_dim}]"
        )
    return (head_dim << 8) | isa_val


def _format_case(head_dim: int, isa: str, impl_isa: str, note: str = "") -> str:
    """Render one macro-continued ``case`` of the dispatch ``switch``.

    Args:
        head_dim: Head dimension baked into the case label and the
            ``constexpr size_t head_dim`` it defines.
        isa: Requested ISA name; together with ``head_dim`` it forms the label.
        impl_isa: ISA whose ``AttentionImpl`` specialization the case selects.
        note: Extra text appended inside the C comment after ``isa=...``.

    Returns:
        The case text. Every line ends with `` \\`` so it can sit inside a
        ``#define``; there is no trailing newline.
    """
    encoded = encode_params(head_dim, isa)
    impl_prefix = "        using attn_impl = cpu_attention::AttentionImpl<"
    # Align the second template argument directly under the first one.
    continuation = " " * len(impl_prefix)
    lines = [
        f"      case {encoded}LL: {{ "
        f"/* head_dim={head_dim}, isa={isa}{note} */ \\",
        f"        constexpr size_t head_dim = {head_dim}; \\",
        f"{impl_prefix}cpu_attention::ISA::{impl_isa}, \\",
        f"{continuation}scalar_t, head_dim>; \\",
        "        return __VA_ARGS__(); \\",
        "      } \\",
    ]
    return "\n".join(lines)


def generate_cases_for_isa_group(isa_list: list[str]) -> str:
    """Generate switch cases for a specific ISA group.

    Cases for ``HEAD_DIMS_32`` come first, then cases for ``HEAD_DIMS_16``.
    Within each head_dim, cases follow the order of ``isa_list``.

    Args:
        isa_list: ISA names the enclosing ``CPU_ATTN_DISPATCH`` macro accepts.
            For head dims divisible by 32, ISAs missing from ``ISA_FOR_32``
            are skipped. For head dims divisible only by 16, every listed ISA
            is routed to the VEC16 kernel. Repeated names are ignored.

    Returns:
        The cases joined by newlines, with no trailing newline.

    Raises:
        KeyError: Propagated from ``encode_params`` for unknown ISA names.
    """
    # Duplicate ISAs would emit repeated ``case`` labels, which C++ rejects;
    # dict.fromkeys de-duplicates while preserving the caller's order.
    isas = list(dict.fromkeys(isa_list))

    cases = [
        _format_case(head_dim, isa, isa)
        for head_dim in HEAD_DIMS_32
        for isa in isas
        if isa in ISA_FOR_32
    ]
    # Head dims that are multiples of 16 but not 32 only have a VEC16 kernel,
    # so each requested ISA gets a case that instantiates the VEC16 impl.
    cases.extend(
        _format_case(head_dim, isa, "VEC16", " (using VEC16)")
        for head_dim in HEAD_DIMS_16
        for isa in isas
    )
    return "\n".join(cases)


def generate_helper_function() -> str:
    """Return the C++ source of the ``encode_cpu_attn_params`` helper.

    The emitted function packs ``head_dim`` and the ISA enum into one
    ``int64_t`` using the same ``(head_dim << 8) | isa`` layout as
    ``encode_params``. The text begins and ends with a newline.
    """
    source_lines = [
        "",
        "inline int64_t encode_cpu_attn_params("
        "int64_t head_dim, cpu_attention::ISA isa) {",
        "  return (head_dim << 8) | static_cast<int64_t>(isa);",
        "}",
        "",
    ]
    return "\n".join(source_lines)


def _generate_dispatch_macro(guard: str, isa_list: list[str]) -> str:
    """Emit one preprocessor branch that defines ``CPU_ATTN_DISPATCH``.

    Args:
        guard: Full directive opening the branch, e.g.
            ``#elif defined(__aarch64__)``.
        isa_list: ISA group whose switch cases the macro dispatches to.

    Returns:
        The directive line, the macro definition, and a trailing blank line.
        The macro raises via ``TORCH_CHECK`` for unsupported configurations.
    """
    return (
        f"{guard}\n"
        + """#define CPU_ATTN_DISPATCH(HEAD_DIM, ISA_TYPE, ...) \\
  [&] { \\
    int64_t encoded_params = encode_cpu_attn_params(HEAD_DIM, ISA_TYPE); \\
    switch (encoded_params) { \\
"""
        + generate_cases_for_isa_group(isa_list)
        + """
      default: { \\
        TORCH_CHECK(false, "Unsupported CPU attention configuration: head_dim=" + \\
                    std::to_string(HEAD_DIM) + " isa=" + \\
                    std::to_string(static_cast<int>(ISA_TYPE))); \\
      } \\
    } \\
  }()

"""
    )


def generate_header_file() -> str:
    """Generate the complete header file content.

    The header contains an include guard, the per-platform attention
    includes, the ``encode_cpu_attn_params`` helper, and one
    ``CPU_ATTN_DISPATCH`` definition per platform branch. Only one branch is
    active at compile time, selected by an ``#if``/``#elif``/``#else`` chain.
    """
    preamble = """// auto generated by generate_cpu_attn_dispatch.py
// clang-format off

#ifndef CPU_ATTN_DISPATCH_GENERATED_H
#define CPU_ATTN_DISPATCH_GENERATED_H

#include "cpu_attn_vec.hpp"
#include "cpu_attn_vec16.hpp"

#ifdef CPU_CAPABILITY_AMXBF16
  #include "cpu_attn_amx.hpp"
#endif

#ifdef __aarch64__
  #include "cpu_attn_neon.hpp"
#endif

#ifdef __s390x__
  #include "cpu_attn_vxe.hpp"
#endif

"""

    # (directive, ISA group) for each platform branch, in #if/#elif order.
    # The branches differ only in which ISA-specific kernel joins VEC/VEC16.
    variants = [
        ("#if defined(CPU_CAPABILITY_AMXBF16)", ["AMX", "VEC", "VEC16"]),
        ("#elif defined(__aarch64__)", ["NEON", "VEC", "VEC16"]),
        ("#elif defined(__s390x__)", ["VXE", "VEC", "VEC16"]),
        ("#else", ["VEC", "VEC16"]),  # fallback: portable kernels only
    ]

    parts = [
        preamble,
        generate_helper_function(),
        "\n// Dispatch macro using encoded parameters\n",
    ]
    parts.extend(_generate_dispatch_macro(guard, isas) for guard, isas in variants)
    parts.append(
        "#endif  /* CPU_CAPABILITY_AMXBF16 / __aarch64__ / __s390x__ */\n"
        "\n"
        "#endif  // CPU_ATTN_DISPATCH_GENERATED_H\n"
    )
    return "".join(parts)


def main():
    """Regenerate ``cpu_attn_dispatch_generated.h`` next to this script.

    The header text is built before the output file is opened. An exception
    during generation therefore leaves any previous header intact instead of
    a truncated, empty one.
    """
    output_path = os.path.join(
        os.path.dirname(__file__), "cpu_attn_dispatch_generated.h"
    )
    content = generate_header_file()

    with open(output_path, "w", encoding="utf-8") as f:
        f.write(content)


if __name__ == "__main__":
    main()