import concurrent.futures
import functools
import inspect
import itertools
import pathlib

import ninetoothed
from ninetoothed.aot import _HEADER_PATH

# Location of this script; the build output directory is resolved from it.
CURRENT_FILE_PATH = pathlib.Path(__file__)

# Destination of the generated dispatcher sources: ``build/ninetoothed`` inside
# the directory four levels above this file (presumably the repository root --
# verify if this script is ever moved).
BUILD_DIRECTORY_PATH = (
    functools.reduce(lambda path, _: path.parent, range(4), CURRENT_FILE_PATH)
    / "build"
    / "ninetoothed"
)


def build(premake, constexpr_param_grid, caller, op_name, output_dir):
    """Compile every kernel variant of an operator and emit a C dispatcher.

    One kernel is compiled (via ``_make``, in a worker process) for each point
    of the Cartesian product of ``constexpr_param_grid``. The results are then
    stitched into ``<op_name>.c`` and ``<op_name>.h`` under
    ``BUILD_DIRECTORY_PATH``. Together these define ``launch_<op_name>``, which
    calls the first compiled variant whose constexpr values equal its trailing
    ``int`` arguments. If no variant matches, it returns
    ``INFINI_STATUS_NOT_IMPLEMENTED``.

    Args:
        premake: Called as ``premake(**combination)`` for each grid point and
            expected to return ``(arrangement, application, tensors)``.
        constexpr_param_grid: Mapping from parameter name to an iterable of
            candidate values.
        caller: Forwarded unchanged to ``ninetoothed.make``.
        op_name: Operator name used for the output file names and the
            ``launch_<op_name>`` entry point.
        output_dir: ``pathlib.Path`` that receives the per-kernel files.

    ``premake``, ``caller`` and ``output_dir`` are sent to worker processes,
    so they must be picklable.
    """
    headers = []
    all_param_names = []
    combinations = []
    launches = []

    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(_make, premake, combination, caller, op_name, output_dir)
            for combination in _generate_param_value_combinations(
                constexpr_param_grid
            )
        ]

        # Consume results in submission (grid) order, not completion order.
        # Completion order differs from run to run, which made the include
        # order, the dispatch order and, through the stable sort below, even
        # the parameter order of ``launch_<op_name>`` non-deterministic.
        for future in futures:
            header, param_names, combination, launch = future.result()

            headers.append(header)
            all_param_names.append(param_names)
            combinations.append(combination)
            launches.append(launch)

    includes = "\n".join(f'#include "{header}"' for header in headers)

    # Union of every kernel's runtime parameter names, in first-seen order,
    # visiting the longest lists first. Each list starts with "stream" (see
    # ``_make``), so it stays in front and is the one typed as a stream.
    param_names = list(
        functools.reduce(
            lambda merged, names: merged | dict.fromkeys(names),
            sorted(all_param_names, key=len, reverse=True),
            {},
        )
    )
    param_types = ["NineToothedStream"] + ["NineToothedTensor"] * (
        len(param_names) - 1
    )

    # Constexpr parameters (names already suffixed with "_" by ``_make``)
    # become trailing ``int`` arguments, which the dispatch conditions test.
    constexpr_names = functools.reduce(lambda x, y: x | y, combinations, {})
    param_names.extend(constexpr_names)
    param_types.extend("int" for _ in constexpr_names)

    param_decls = ", ".join(
        f"{param_type} {param_name}"
        for param_name, param_type in zip(param_names, param_types)
    )

    source_file_name = f"{op_name}.c"
    header_file_name = f"{op_name}.h"

    func_sig = f"NineToothedResult launch_{op_name}({param_decls})"

    joined_launches = "\n".join(launches)

    # Give the launcher C linkage when the header is included from C++.
    op_decl = f'#ifdef __cplusplus\nextern "C" {func_sig};\n#else\n{func_sig};\n#endif'
    # Try each variant's ``if (...) return ...;`` in order; fall through when
    # none matches the runtime constexpr arguments.
    op_def = f"""{func_sig} {{
{joined_launches}
    return INFINI_STATUS_NOT_IMPLEMENTED;
}}"""

    source_content = f"""#include "{header_file_name}"

#include "infinicore.h"

{includes}\n\n{op_def}\n"""
    header_content = f"""#include "{_HEADER_PATH}"
\n{op_decl}\n"""

    # ``write_text`` does not create missing parent directories.
    BUILD_DIRECTORY_PATH.mkdir(parents=True, exist_ok=True)
    (BUILD_DIRECTORY_PATH / source_file_name).write_text(source_content)
    (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)
def _make(premake, combination, caller, op_name, output_dir):
    """Compile one kernel variant and describe how to dispatch to it.

    Args:
        premake: Called as ``premake(**combination)``; returns
            ``(arrangement, application, tensors)`` for ``ninetoothed.make``.
        combination: One grid point, mapping constexpr parameter names to
            values. It is not modified.
        caller: Forwarded unchanged to ``ninetoothed.make``.
        op_name: Prefix of the generated kernel name.
        output_dir: ``pathlib.Path`` passed to ``ninetoothed.make`` as the
            output location.

    Returns:
        A tuple ``(header, param_names, combination, launch)``:

        - ``header``: the path ``output_dir / "<kernel_name>.h"``, presumably
          produced by ``ninetoothed.make``.
        - ``param_names``: ``"stream"`` followed by ``application``'s
          parameter names.
        - ``combination``: the combination as written in C. Names carry a
          trailing ``_`` and string values become ``INFINI_DTYPE_*`` names.
        - ``launch``: a C ``if`` statement that calls
          ``launch_<kernel_name>`` when that combination matches.
    """
    arrangement, application, tensors = premake(**combination)

    # Build a new mapping instead of rewriting ``combination`` in place, so
    # the caller's dict is never altered. String values (presumably dtype
    # names) are mapped to enum names, e.g. "fp16" -> INFINI_DTYPE_F16.
    # NOTE(review): the trailing "_" presumably keeps these C names from
    # clashing with tensor parameter names -- confirm.
    c_combination = {
        f"{name}_": (
            f"INFINI_DTYPE_{value.replace('fp', 'F').upper()}"
            if isinstance(value, str)
            else value
        )
        for name, value in combination.items()
    }

    kernel_name = f"{op_name}_{_generate_suffix(c_combination.values())}"

    ninetoothed.make(
        arrangement,
        application,
        tensors,
        caller=caller,
        kernel_name=kernel_name,
        output_dir=output_dir,
    )

    header = output_dir / f"{kernel_name}.h"
    param_names = ("stream",) + tuple(inspect.signature(application).parameters)
    launch = f"""    if ({_generate_condition(c_combination)})
        return launch_{kernel_name}({", ".join(param_names)});"""

    return header, param_names, c_combination, launch
def _generate_condition(combination):
    return " && ".join(f"{param} == {value}" for param, value in combination.items())


def _generate_suffix(values):
    return "_".join(f"{value}" for value in values)


def _generate_param_value_combinations(param_grid):
    keys = list(param_grid.keys())
    value_combinations = itertools.product(*param_grid.values())

    return tuple(dict(zip(keys, combination)) for combination in value_combinations)