generate.py 3.6 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
6
from enum import IntEnum
carlushuang's avatar
carlushuang committed
7
from pathlib import Path
8
9
import pkgutil
import sys
10
from typing import List, Optional
carlushuang's avatar
carlushuang committed
11

12
import codegen.ops
13
from codegen.cmake_config import *
carlushuang's avatar
carlushuang committed
14
15


16
17
18
class HandlerId(IntEnum):
    """Position of each callable inside a registered handler pair.

    Every entry of the module-level ``handlers`` dict is a
    ``(list_blobs, write_blobs)`` tuple, so a member's integer value is the
    index used to pick the matching callable out of that tuple.
    """
    # Values follow declaration order: LIST_BLOBS -> 0, WRITE_BLOBS -> 1.
    LIST_BLOBS, WRITE_BLOBS = range(2)
carlushuang's avatar
carlushuang committed
19

20
21
22
23
24
25
26
27
28
29
30
31
# Inspect all modules under 'codegen.ops' and register their API handlers.
# Each op module must expose `list_blobs` and `write_blobs`; its API name is the
# module name with the 'fmha_' prefix stripped (module 'fmha_fwd' -> API 'fwd').
import importlib

unwanted_prefix = 'fmha_'
ops = []
handlers = {}
for _, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
    # import_module() registers the op under its fully-qualified name and simply
    # returns the cached module if it was already imported. The previous
    # Loader.load_module() call is deprecated, registered the op as a top-level
    # module, and skipped (i.e. silently dropped) ops that were already imported.
    op = importlib.import_module('%s.%s' % (codegen.ops.__name__, module_name))
    ops.append(op)
    if module_name.startswith(unwanted_prefix):
        api_name = module_name[len(unwanted_prefix):]
    else:
        api_name = module_name
    handlers[api_name] = (op.list_blobs, op.write_blobs)

# An explicit raise (not assert) so the check survives `python -O`.
if not handlers:
    raise RuntimeError('no codegen ops found under %s' % codegen.ops.__name__)
Dan Yao's avatar
Dan Yao committed
32

33
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate the kernel source blobs for every API in ``api_list``.

    Args:
        output_dir: destination root. When ``None`` the blobs are written into
            the directory containing this script (``GEN_DIR`` is NOT appended);
            otherwise they go to ``<output_dir>/GEN_DIR``.
        api_list: names of registered APIs (keys of the module-level
            ``handlers`` dict).
        kernel_filter: optional kernel filter forwarded unchanged to each
            handler (the CLI help describes it as an fnmatch pattern).
        receipt: codegen receipt forwarded unchanged to each handler.
        mask_impl: mask implementation name forwarded unchanged to each handler.

    Raises:
        KeyError: if any name in ``api_list`` is not a registered API. All
            names are checked before any directory is created or blob written.
    """
    # Validate up front so a typo in a later API name does not leave a
    # partially generated output directory behind.
    unknown = [api for api in api_list if api not in handlers]
    if unknown:
        raise KeyError('unknown API(s): %s; available: %s'
                       % (', '.join(unknown), ', '.join(sorted(handlers))))

    if output_dir is None:
        output_dir = Path(__file__).parent
    else:
        output_dir = Path(output_dir) / GEN_DIR

    output_dir.mkdir(parents=True, exist_ok=True)

    for api in api_list:
        handler = handlers[api][HandlerId.WRITE_BLOBS]
        handler(output_dir, kernel_filter, receipt, mask_impl)
carlushuang's avatar
carlushuang committed
44
45

# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Write the list of files that would be generated for ``api_list``.

    ``output_file`` is truncated first; then each API's list handler is called
    with its path (presumably appending its blob names -- confirm in
    ``codegen.ops``).

    Args:
        output_file: path of the listing file; required.
        api_list: names of registered APIs (keys of the module-level
            ``handlers`` dict).
        kernel_filter: optional kernel filter forwarded unchanged to each handler.
        receipt: codegen receipt forwarded unchanged to each handler.
        mask_impl: mask implementation name forwarded unchanged to each handler.

    Raises:
        ValueError: if ``output_file`` is None.
        KeyError: if any name in ``api_list`` is not a registered API. This is
            checked before ``output_file`` is truncated, so an existing listing
            is not clobbered by a typo.
    """
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if output_file is None:
        raise ValueError('list_blobs requires an output file path')

    unknown = [api for api in api_list if api not in handlers]
    if unknown:
        raise KeyError('unknown API(s): %s; available: %s'
                       % (', '.join(unknown), ', '.join(sorted(handlers))))

    file_path = Path(output_file)

    # create an empty file / drop its contents if it exists
    with open(file_path, "w"):
        pass

    for api in api_list:
        handler = handlers[api][HandlerId.LIST_BLOBS]
        handler(file_path, kernel_filter, receipt, mask_impl)
carlushuang's avatar
carlushuang committed
56
57
58
59

if __name__ == "__main__":
    # Command-line entry point: parse options, validate the requested APIs,
    # then either list the blobs to a file (-l) or write them (-o).
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    parser.add_argument(
        "-d",
        # we keep 'direction' option for backward compatibility; being the first
        # long option it also determines the attribute name (args.direction)
        "--direction",
        "-a",
        "--api",
        default='bwd',
        required=False,
        # %(default)s keeps the help text in sync with the real default
        help="supply API(s) to generate, separated by comma "
             "(default: %(default)s; available: " + ', '.join(sorted(handlers)) + ")"
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic (default: %(default)s)"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        # type=int turns a non-numeric value into a clean argparse error
        # instead of a ValueError traceback after parsing
        type=int,
        default=3,
        required=False,
        help="codegen receipt, passed through to each API's handlers "
             "(default: %(default)s). 0: generate only 8xhdim coverage\n"
             "  1: generate more instance to cover all hdim\n"
             "  2: Only generate instance for Flash attention integration"
    )

    args = parser.parse_args()
    # Tolerate whitespace around commas and stray empty entries ("fwd, bwd,").
    api_list = [api.strip() for api in args.direction.split(',') if api.strip()]
    if not api_list:
        parser.error("no API given to -a/--api")
    unknown = [api for api in api_list if api not in handlers]
    if unknown:
        parser.error("unknown API(s): %s (available: %s)"
                     % (', '.join(unknown), ', '.join(sorted(handlers))))

    if args.list_blobs is not None:
        list_blobs(args.list_blobs, api_list, args.filter, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, api_list, args.filter, args.receipt, mask_impl=args.mask)