generate.py 4.29 KB
Newer Older
Max Podkorytov's avatar
Max Podkorytov committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
from enum import IntEnum
import importlib
from pathlib import Path
import pkgutil
import sys
from typing import List, Optional

import codegen.ops
from codegen.cmake_config import *


class HandlerId(IntEnum):
    """Position of each callable inside a ``handlers`` entry.

    Every registered API maps to a ``(list_blobs, write_blobs)`` tuple; the
    members below name the two slots so callers can index by intent.
    """

    LIST_BLOBS = 0   # slot holding the op module's list_blobs function
    WRITE_BLOBS = 1  # slot holding the op module's write_blobs function

# Import every module under 'codegen.ops' and register its API handlers.
# Each op module must define `list_blobs` and `write_blobs`; it is registered
# under its module name with a leading 'fmha_' removed (e.g. 'fmha_fwd' -> 'fwd').
unwanted_prefix = 'fmha_'
ops = []
handlers = {}
for _, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    # import_module() loads the op under its fully-qualified package name and
    # returns the cached module when it is already in sys.modules, so an op that
    # was imported earlier is still registered and is never loaded twice.
    op = importlib.import_module('%s.%s' % (codegen.ops.__name__, module_name))
    ops.append(op)
    api = module_name[len(unwanted_prefix):] if module_name.startswith(unwanted_prefix) else module_name
    handlers[api] = (op.list_blobs, op.write_blobs)
# Raise explicitly (instead of assert) so the check also runs under `python -O`.
if not handlers:
    raise RuntimeError('no API handlers found under %s' % codegen.ops.__name__)

def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
    """Generate the kernel source files for every requested API.

    Args:
        output_dir: Base output directory. When None, blobs are written into the
            directory containing this script; otherwise into
            ``<output_dir>/GEN_DIR`` (``GEN_DIR`` presumably comes from
            ``codegen.cmake_config``).
        api_list: API names to generate; each must be a key of ``handlers``.
        kernel_filter: Optional fnmatch-style kernel filter, forwarded as-is.
        receipt: Codegen receipt, forwarded as-is to each handler.
        mask_impl: Mask implementation name, forwarded as-is.
        score_mod_expr: C++ score-mod expression, forwarded as-is.
        pre_softmax_expr: C++ pre-softmax expression, forwarded as-is.

    Raises:
        KeyError: If an entry of ``api_list`` has no registered handler. This is
            checked before the output directory is created.
    """
    # Resolve every handler up front so an unknown API fails before any side effect.
    unknown = [api for api in api_list if api not in handlers]
    if unknown:
        raise KeyError('unknown API(s) %s; available: %s'
                       % (', '.join(unknown), ', '.join(sorted(handlers))))
    write_handlers = [handlers[api][HandlerId.WRITE_BLOBS] for api in api_list]

    if output_dir is None:
        out_path = Path(__file__).parent
    else:
        out_path = Path(output_dir) / GEN_DIR
    out_path.mkdir(parents=True, exist_ok=True)

    for handler in write_handlers:
        handler(out_path, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)


# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None:
    """Record the files that ``write_blobs`` would generate for the given APIs.

    The list file is created (or emptied if it exists) and then passed to each
    API's list handler.

    Args:
        output_file: Path of the list file; must not be None.
        api_list: API names to list; each must be a key of ``handlers``.
        kernel_filter: Optional fnmatch-style kernel filter, forwarded as-is.
        receipt: Codegen receipt, forwarded as-is to each handler.
        mask_impl: Mask implementation name, forwarded as-is.
        score_mod_expr: C++ score-mod expression, forwarded as-is.
        pre_softmax_expr: C++ pre-softmax expression, forwarded as-is.

    Raises:
        ValueError: If ``output_file`` is None.
        KeyError: If an entry of ``api_list`` has no registered handler. This is
            checked before the list file is touched.
    """
    # Explicit check (not assert) so validation survives `python -O`.
    if output_file is None:
        raise ValueError('list_blobs requires an output file path')

    # Resolve handlers before truncating the file so a typo in the API list
    # does not leave behind an empty blob list.
    unknown = [api for api in api_list if api not in handlers]
    if unknown:
        raise KeyError('unknown API(s) %s; available: %s'
                       % (', '.join(unknown), ', '.join(sorted(handlers))))
    list_handlers = [handlers[api][HandlerId.LIST_BLOBS] for api in api_list]

    file_path = Path(output_file)
    # Create an empty file / drop its contents if it exists.
    file_path.write_text('')

    for handler in list_handlers:
        handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    parser.add_argument(
        "-d",
        "--direction", # we keep 'direction' option for backward compatibility
        "-a",
        "--api",
        default='fwd',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    # type=int lets argparse reject non-integer receipts with a usage error
    # instead of a traceback from a later int() conversion.
    parser.add_argument(
        "-r",
        "--receipt",
        type=int,
        default=0,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n"
             "  1: generate more instance to cover all hdim\n"
             "  2: Only generate instance for Flash attention integration"
    )

    parser.add_argument(
        "--score_mod_expr",
        default="s",
        required=False,
        help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables"
    )

    parser.add_argument(
        "--pre_softmax_expr",
        default="s",
        required=False,
        help="flex attention's pre_softmax function, a cpp expression with `s` variable"
    )

    args = parser.parse_args()
    # Accept "fwd, bwd" and ignore empty entries such as a trailing comma.
    api_list = [api.strip() for api in args.direction.split(',') if api.strip()]
    if not api_list:
        parser.error("no API specified")

    if args.list_blobs is not None:
        list_blobs(args.list_blobs, api_list, args.filter, args.receipt, mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr)
    else:
        write_blobs(args.output_dir, api_list, args.filter, args.receipt, mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr)