generate.py 3.67 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
from enum import IntEnum
import importlib.util
from pathlib import Path
import pkgutil
import sys
from typing import List, Optional

import codegen.ops
from codegen.cmake_config import *
carlushuang's avatar
carlushuang committed
14
15


16
17
18
class HandlerId(IntEnum):
    """Index of each callable inside a per-API handler tuple.

    The values are used directly as tuple indices, so they must match the
    order in which the callables are stored: (list_blobs, write_blobs).
    """

    # position of the module's `list_blobs` callable
    LIST_BLOBS = 0
    # position of the module's `write_blobs` callable
    WRITE_BLOBS = 1
carlushuang's avatar
carlushuang committed
19

20
21
22
23
24
25
26
27
28
29
30
31
# Inspect all modules under 'codegen.ops' and register their API handlers.
#
# Every discovered module is executed under its bare name (e.g. 'fmha_fwd') and
# must expose `list_blobs` and `write_blobs` callables. Modules that are already
# present in sys.modules under their package-qualified name are skipped and
# therefore do not get a handler entry.
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
    if full_module_name in sys.modules:
        continue

    # Loader.load_module() is deprecated; build the module from its spec instead.
    spec = importer.find_spec(module_name)
    module = importlib.util.module_from_spec(spec)
    # Register the module before running it (as load_module() did) so code that
    # looks it up by name while it is being executed can find it.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # do not leave a half-initialized module behind
        sys.modules.pop(module_name, None)
        raise
    ops.append(module)

# Handlers are keyed by the module name without its 'fmha_' prefix; each value
# is a tuple laid out in HandlerId order: (list_blobs, write_blobs).
unwanted_prefix = 'fmha_'
handlers = {
    (op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__):
        (op.list_blobs, op.write_blobs)
    for op in ops
}
assert 0 < len(handlers)
Dan Yao's avatar
Dan Yao committed
32

33
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Run the WRITE_BLOBS handler of every requested API against one directory.

    Args:
        output_dir: Base directory; the blobs go into its ``GEN_DIR``
            sub-directory. When ``None``, the directory containing this
            script is used as-is.
        api_list: Keys into ``handlers`` naming the APIs to process.
        kernel_filter: Passed through unchanged to each handler.
        receipt: Passed through unchanged to each handler.
        mask_impl: Passed through unchanged to each handler.

    Raises:
        KeyError: If an entry of ``api_list`` has no registered handler.
    """
    target_dir = Path(__file__).parent if output_dir is None else Path(output_dir) / GEN_DIR
    # make sure the destination exists before any handler runs
    target_dir.mkdir(parents=True, exist_ok=True)

    for api_name in api_list:
        write_handler = handlers[api_name][HandlerId.WRITE_BLOBS]
        write_handler(target_dir, kernel_filter, receipt, mask_impl)
carlushuang's avatar
carlushuang committed
44
45

# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Empty ``output_file``, then run the LIST_BLOBS handler of every requested API on it.

    Args:
        output_file: Path of the listing file; must not be ``None``.
        api_list: Keys into ``handlers`` naming the APIs to process.
        kernel_filter: Passed through unchanged to each handler.
        receipt: Passed through unchanged to each handler.
        mask_impl: Passed through unchanged to each handler.

    Raises:
        AssertionError: If ``output_file`` is ``None``.
        KeyError: If an entry of ``api_list`` has no registered handler.
    """
    assert output_file is not None
    listing_path = Path(output_file)

    # start from an empty file: create it, or truncate whatever is already there
    with open(listing_path, "w"):
        pass

    for api_name in api_list:
        list_handler = handlers[api_name][HandlerId.LIST_BLOBS]
        list_handler(listing_path, kernel_filter, receipt, mask_impl)
carlushuang's avatar
carlushuang committed
56
57
58
59

if __name__ == "__main__":
    # Command-line entry point: parse the options, then either list the blobs
    # into a file (-l) or write them into a directory (-o / default location).
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    # argparse derives the attribute name from the first long option, so the
    # parsed value is available as `args.direction` for both spellings.
    parser.add_argument(
        "-d",
        "--direction", # we keep 'direction' option for backward compatibility
        "-a",
        "--api",
        default='fwd',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    # type=int lets argparse reject a non-integer receipt with a usage error
    # instead of failing later with a ValueError traceback.
    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        type=int,
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n"
             "  1: generate more instance to cover all hdim\n"
             "  2: Only generate instance for Flash attention integration\n"
             "  4: Only generate instance for PyTorch integration"
    )

    args = parser.parse_args()
    api_list = args.direction.split(',')

    # Report unknown API names as a usage error rather than a bare KeyError
    # raised from deep inside list_blobs()/write_blobs().
    unknown_apis = [api for api in api_list if api not in handlers]
    if unknown_apis:
        parser.error("unknown API(s): %s (available: %s)"
                     % (', '.join(repr(api) for api in unknown_apis), ', '.join(sorted(handlers))))

    if args.list_blobs is not None:
        list_blobs(args.list_blobs, api_list, args.filter, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, api_list, args.filter, args.receipt, mask_impl=args.mask)