"docs/backend/send_request.ipynb" did not exist on "72e979bfb5ed031282deef800774cbcde3d572b3"
generate.py 3.19 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
6
from enum import IntEnum
carlushuang's avatar
carlushuang committed
7
from pathlib import Path
8
from typing import List, Optional
carlushuang's avatar
carlushuang committed
9

10
11
12
13
14
15
from codegen.cmake_config import *
from codegen.ops import (
    fmha_fwd,
    fmha_fwd_splitkv,
    fmha_bwd
)
carlushuang's avatar
carlushuang committed
16
17


18
19
20
class HandlerId(IntEnum):
    """Index of each callback within a ``(list_blobs, write_blobs)`` handler tuple."""

    LIST_BLOBS = 0   # position of the blob-listing callback
    WRITE_BLOBS = 1  # position of the blob-writing callback
carlushuang's avatar
carlushuang committed
21

22
23
24
25
# API name (as given on the command line, comma-separated) -> its
# (list_blobs, write_blobs) callbacks; tuple order must match HandlerId.
handlers = dict(
    fwd=(fmha_fwd.list_blobs, fmha_fwd.write_blobs),
    fwd_splitkv=(fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
    bwd=(fmha_bwd.list_blobs, fmha_bwd.write_blobs),
)

28
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate the kernel blobs for every API in ``api_list``.

    Args:
        output_dir: base directory; blobs go to ``<output_dir>/GEN_DIR``.
            When None, the directory containing this script is used directly
            (without the GEN_DIR suffix).
        api_list: keys of ``handlers`` (e.g. 'fwd', 'fwd_splitkv', 'bwd').
        kernel_filter: optional fnmatch pattern forwarded to each handler.
        receipt: codegen receipt number forwarded to each handler.
        mask_impl: mask implementation name forwarded to each handler.

    Raises:
        KeyError: if an entry of ``api_list`` is not a key of ``handlers``.
    """
    target_dir = Path(__file__).parent if output_dir is None else Path(output_dir) / GEN_DIR
    # Create the destination up front so every handler can write into it.
    target_dir.mkdir(parents=True, exist_ok=True)

    for api_name in api_list:
        writer = handlers[api_name][HandlerId.WRITE_BLOBS]
        writer(target_dir, kernel_filter, receipt, mask_impl)
carlushuang's avatar
carlushuang committed
39
40

# list all the files that will be generated
41
def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Ask each API in ``api_list`` to list the blobs it would generate.

    Args:
        output_file: path handed to every selected handler as a ``Path``. All
            APIs receive the same path; whether a handler appends to or
            overwrites it is decided in ``codegen.ops`` (NOTE(review):
            presumably appends - confirm).
        api_list: keys of ``handlers`` (e.g. 'fwd', 'fwd_splitkv', 'bwd').
        kernel_filter: optional fnmatch pattern forwarded to each handler.
        receipt: codegen receipt number forwarded to each handler.
        mask_impl: mask implementation name forwarded to each handler.

    Raises:
        ValueError: if ``output_file`` is None.
        KeyError: if an entry of ``api_list`` is not a key of ``handlers``.
    """
    # An explicit check, unlike ``assert``, is not stripped by ``python -O``;
    # otherwise Path(None) would fail later with an unhelpful TypeError.
    if output_file is None:
        raise ValueError("list_blobs requires an output file path")
    file_path = Path(output_file)

    for api in api_list:
        handler = handlers[api][HandlerId.LIST_BLOBS]
        handler(file_path, kernel_filter, receipt, mask_impl)
carlushuang's avatar
carlushuang committed
48
49
50
51

if __name__ == "__main__":
    # Command-line entry point: pick the fmha API(s) to generate, then either
    # list the blob files (-l, takes precedence) or write the blobs (-o).
    parser = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    # argparse derives the destination name from the first long option, so the
    # value lands in ``args.direction`` whichever of -d/--direction/-a/--api is used.
    parser.add_argument(
        "-d",
        "--direction", # we keep 'direction' option for backward compatibility
        "-a",
        "--api",
        default='fwd',
        required=False,
        help="supply API(s) to generate (default: fwd). separated by comma."
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory"
    )
    parser.add_argument(
        "-l",
        "--list_blobs",
        required=False,
        help="list all the kernels to a file"
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        help="filter out kernels that need to generate, using fnmatch module"
    )

    parser.add_argument(
        "-m",
        "--mask",
        default="simplified",
        required=False,
        help="mask implementation, simplified/generic"
    )

    parser.add_argument(
        "-r",
        "--receipt",
        default=0,
        type=int,  # non-integers become a usage error rather than a traceback
        required=False,
        help="codegen receipt. 0: generate only 8xhdim coverage\n"
             "  1: generate more instance to cover all hdim\n"
             "  2: Only generate instance for Flash attention integration"
    )

    args = parser.parse_args()

    # Tolerate spaces ("fwd, bwd") and stray commas ("fwd,"), and reject unknown
    # names up front instead of failing later with a KeyError on ``handlers``.
    api_list = [api.strip() for api in args.direction.split(',') if api.strip()]
    if not api_list:
        parser.error("no API given; choose from: " + ", ".join(handlers))
    unknown_apis = [api for api in api_list if api not in handlers]
    if unknown_apis:
        parser.error("unknown API(s): " + ", ".join(unknown_apis)
                     + "; choose from: " + ", ".join(handlers))

    if args.list_blobs is not None:
        list_blobs(args.list_blobs, api_list, args.filter, args.receipt, mask_impl=args.mask)
    else:
        write_blobs(args.output_dir, api_list, args.filter, args.receipt, mask_impl=args.mask)