Unverified Commit a19efb54 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/232: 接入九齿

parent 301cc55c
import importlib
import pathlib
from infiniop.ninetoothed.build import BUILD_DIRECTORY_PATH
CURRENT_FILE_PATH = pathlib.Path(__file__)
SRC_DIR_PATH = CURRENT_FILE_PATH.parent.parent / "src"
def _find_and_build_ops():
ops_path = SRC_DIR_PATH / "infiniop" / "ops"
for op_dir in ops_path.iterdir():
ninetoothed_path = op_dir / "ninetoothed"
if ninetoothed_path.is_dir():
module_path = ninetoothed_path / "build"
relative_path = module_path.relative_to(SRC_DIR_PATH)
import_name = ".".join(relative_path.parts)
module = importlib.import_module(import_name)
module.build()
if __name__ == "__main__":
BUILD_DIRECTORY_PATH.mkdir(exist_ok=True)
_find_and_build_ops()
...@@ -62,7 +62,7 @@ def format_file(file: Path, check: bool, formatter) -> bool: ...@@ -62,7 +62,7 @@ def format_file(file: Path, check: bool, formatter) -> bool:
text=True, text=True,
check=True, check=True,
) )
if process.stderr: if process.returncode != 0:
print(f"{Fore.YELLOW}{file} is not formatted.{Style.RESET_ALL}") print(f"{Fore.YELLOW}{file} is not formatted.{Style.RESET_ALL}")
print( print(
f"Use {Fore.CYAN}{formatter} {file}{Style.RESET_ALL} to format it." f"Use {Fore.CYAN}{formatter} {file}{Style.RESET_ALL} to format it."
......
import functools
import inspect
import itertools
import pathlib
import ninetoothed
from ninetoothed.aot import _HEADER_PATH
CURRENT_FILE_PATH = pathlib.Path(__file__)
BUILD_DIRECTORY_PATH = (
CURRENT_FILE_PATH.parent.parent.parent.parent / "build" / "ninetoothed"
)
def build(premake, constexpr_param_grid, caller, op_name, output_dir):
headers = []
all_param_names = []
launches = []
for combination in _generate_param_value_combinations(constexpr_param_grid):
arrangement, application, tensors = premake(**combination)
for param_name, param_value in combination.items():
if isinstance(param_value, str):
combination[param_name] = (
f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F')}"
)
combination = {f"{name}_": value for name, value in combination.items()}
kernel_name = f"{op_name}_{_generate_suffix(combination.values())}"
ninetoothed.make(
arrangement,
application,
tensors,
caller=caller,
kernel_name=kernel_name,
output_dir=output_dir,
)
header = output_dir / f"{kernel_name}.h"
param_names = ("stream",) + tuple(
inspect.signature(application).parameters.keys()
)
launch = f""" if ({_generate_condition(combination)})
return launch_{kernel_name}({", ".join(param_names)});"""
headers.append(header)
all_param_names.append(param_names)
launches.append(launch)
includes = "\n".join(f'#include "{header}"' for header in headers)
param_names = list(
functools.reduce(
lambda x, y: dict.fromkeys(x) | dict.fromkeys(y),
sorted(all_param_names, key=len, reverse=True),
{},
)
)
param_types = [
"NineToothedStream",
] + ["NineToothedTensor" for _ in range(len(param_names) - 1)]
for param_name in combination:
param_names.append(param_name)
param_types.append("int")
param_decls = ", ".join(
f"{type} {param}" for param, type in zip(param_names, param_types)
)
source_file_name = f"{op_name}.c"
header_file_name = f"{op_name}.h"
func_sig = f"NineToothedResult launch_{op_name}({param_decls})"
op_decl = f'#ifdef __cplusplus\nextern "C" {func_sig};\n#else\n{func_sig};\n#endif'
op_def = f"""{func_sig} {{
{"\n".join(launches)}
return INFINI_STATUS_NOT_IMPLEMENTED;
}}"""
source_content = f"""#include "{header_file_name}"
#include "infinicore.h"
{includes}\n\n{op_def}\n"""
header_content = f"""#include "{_HEADER_PATH}"
\n{op_decl}\n"""
(BUILD_DIRECTORY_PATH / source_file_name).write_text(source_content)
(BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content)
def _generate_condition(combination):
return " && ".join(f"{param} == {value}" for param, value in combination.items())
def _generate_suffix(values):
return "_".join(f"{value}" for value in values)
def _generate_param_value_combinations(param_grid):
keys = list(param_grid.keys())
value_combinations = itertools.product(*param_grid.values())
return tuple(dict(zip(keys, combination)) for combination in value_combinations)
...@@ -15,7 +15,8 @@ target("infiniop-cuda") ...@@ -15,7 +15,8 @@ target("infiniop-cuda")
set_policy("build.cuda.devlink", true) set_policy("build.cuda.devlink", true)
set_toolchains("cuda") set_toolchains("cuda")
add_links("cublas", "cudnn") add_links("cuda", "cublas", "cudnn")
add_linkdirs(CUDA_ROOT .. "/lib64/stubs")
add_cugencodes("native") add_cugencodes("native")
if is_plat("windows") then if is_plat("windows") then
...@@ -38,7 +39,7 @@ target("infiniop-cuda") ...@@ -38,7 +39,7 @@ target("infiniop-cuda")
end end
set_languages("cxx17") set_languages("cxx17")
add_files("../src/infiniop/devices/cuda/*.cu", "../src/infiniop/ops/*/cuda/*.cu") add_files("../src/infiniop/devices/cuda/*.cu", "../src/infiniop/ops/*/cuda/*.cu", "../build/ninetoothed/*.c")
target_end() target_end()
target("infinirt-cuda") target("infinirt-cuda")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment