# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
from enum import IntEnum
from pathlib import Path

# All TILE_OP_* constants below are ``str.format`` templates: single braces are
# substitution fields, doubled braces (``{{`` / ``}}``) become literal C++ braces.
#
# NOTE(review): several templates contain a bare ``#include``, ``template`` or
# ``remove_cvref_t`` with no ``<...>`` part. It looks like angle-bracket content
# was lost from this copy of the file -- confirm against the upstream script
# before relying on the generated C++ compiling as-is.

# Banner prepended to every generated file.
TILE_OP_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

// auto generated by gen_tile_op.py
"""

# Host-side API implementation (<op>_api.cpp).
TILE_OP_HOST_API_CPP = """
#include "{op_host_api_file}.hpp"

float {op_host_api}({op_traits} t, {op_kargs} a, ck_tile::stream_config s)
{{
    // TODO: write some dispatch code
    (void)t;
    {{
        using problem = ck_tile::{op_problem};
        using pipeline = ck_tile::{op_pipeline};
        using kernel = ck_tile::{op_kernel};

        auto kargs = kernel::MakeKargs(a);
        const dim3 grids = kernel::GridSize(a);
        constexpr dim3 blocks = kernel::BlockSize();

        float ave_time = ck_tile::launch_kernel(s, ck_tile::make_kernel(kernel{{}}, grids, blocks, 0, kargs));
        return ave_time;
    }}
    return -1; // not supported by this API
}}
"""

# Host-side API declaration (<op>_api.hpp).
TILE_OP_HOST_API_HPP = """
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/{op_name}.hpp"

struct {op_traits}
{{
    // TODO: add more trait for selecting kernel
}};

struct {op_kargs} : public ck_tile::{k_op_host_args}
{{
}};

float {op_host_api}({op_traits} t, {op_kargs} a, ck_tile::stream_config s);
"""

# Device kernel skeleton (kernel/<op>_kernel.hpp).
TILE_OP_KERNEL = """
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include
#include

namespace ck_tile {{

struct {k_op_host_args}
{{
    // TODO: add host args
}};

template
struct {k_op_kernel}
{{
    using Pipeline = remove_cvref_t;
    using Problem = remove_cvref_t;

    struct {k_op_kargs}
    {{
        // TODO: add kernel args
    }};
    using Kargs = {k_op_kargs};
    using Hargs = {k_op_host_args};

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {{
        // TODO: return how many grids
        (void)h;
        return dim3(1);
    }}

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {{
        (void)h;
        Kargs k;
        return k;
    }}

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() {{ return Problem::BlockSize; }}

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {{
        // entry point of this kernel
        (void)kargs;
        // Pipeline{{}}(input_window, output_window, loop_stride);
    }}
}};
}} // namespace ck_tile
"""

# Pipeline skeleton (pipeline/<op>_pipeline.hpp).
TILE_OP_PIPELINE = """
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_policy.hpp"
#include
#include

namespace ck_tile {{

template
struct {k_op_pipeline}
{{
    // TODO: this kernel only support warp per row
    using Problem = remove_cvref_t;
    using Policy = remove_cvref_t;

    CK_TILE_DEVICE auto operator()()
    {{
        // pipeline is here
    }}
}};
}} // namespace ck_tile
"""

# Policy skeleton (pipeline/<op>_policy.hpp).
TILE_OP_POLICY = """
#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {{

struct {k_op_policy}
{{
    template
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {{
        // TODO: create some discriptor
        // return make_static_tile_distribution(
        //     tile_distribution_encoding,
        //     tuple>,
        //     tuple, sequence<1>>,
        //     tuple, sequence<2>>,
        //     sequence<1, 1>,
        //     sequence<0, 3>>{{}});
    }}
}};
}} // namespace ck_tile
"""

# Problem skeleton (pipeline/<op>_problem.hpp).
TILE_OP_PROBLEM = """
#pragma once

#include "ck_tile/core.hpp"
#include
#include

namespace ck_tile {{

template
struct {k_op_problem}
{{
    static constexpr index_t BlockSize = BlockSize_;
    static constexpr index_t WarpSize = get_warp_size();
}};
}} // namespace ck_tile
"""

# Umbrella header pulling in all generated device headers (ops/<op>.hpp).
TILE_OP_DEVICE_HEADER = """
#pragma once

#include "ck_tile/ops/{op_name}/kernel/{op_name}_kernel.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_pipeline.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_policy.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_problem.hpp"
"""


def snake_to_pascal_case(snake_case_str):
    """Convert a ``snake_case`` identifier to ``PascalCase``.

    The input is split on underscores and each piece is passed through
    ``str.title()`` before being concatenated, e.g. ``"foo_host_args"``
    becomes ``"FooHostArgs"``.
    """
    return ''.join(word.title() for word in snake_case_str.split('_'))


class tile_op_template:
    """Generate the skeleton source files for one ck_tile operator.

    ``op_*`` properties are the snake_case identifiers substituted into the
    host API templates; ``k_op_*`` properties are the PascalCase identifiers
    substituted into the device-side (kernel/pipeline/policy/problem)
    templates.
    """

    def __init__(self, base_dir: Path, op_name: str):
        # base_dir: directory that receives every generated file.
        # op_name: snake_case operator name used to derive all identifiers.
        self.base_dir = base_dir
        self.op_name = op_name

    @property
    def op_host_api_file(self) -> str:
        """Stem of the host API source/header file names."""
        return self.op_name + "_api"

    @property
    def op_host_api(self) -> str:
        """Name of the host API entry function (the op name itself)."""
        return self.op_name

    @property
    def op_problem(self) -> str:
        return self.op_name + "_problem"

    @property
    def op_pipeline(self) -> str:
        return self.op_name + "_pipeline"

    @property
    def op_policy(self) -> str:
        return self.op_name + "_policy"

    @property
    def op_kernel(self) -> str:
        return self.op_name + "_kernel"

    @property
    def op_traits(self) -> str:
        return self.op_name + "_traits"

    @property
    def op_kargs(self) -> str:
        return self.op_name + "_kargs"

    @property
    def k_op_host_args(self) -> str:
        return snake_to_pascal_case(self.op_name + "_host_args")

    @property
    def k_op_kernel(self) -> str:
        # Previously spelled "_kenrel", which generated e.g. `FooKenrel`
        # instead of `FooKernel`.
        return snake_to_pascal_case(self.op_name + "_kernel")

    @property
    def k_op_kargs(self) -> str:
        return snake_to_pascal_case(self.op_name + "_kargs")

    @property
    def k_op_pipeline(self) -> str:
        return snake_to_pascal_case(self.op_name + "_pipeline")

    @property
    def k_op_policy(self) -> str:
        return snake_to_pascal_case(self.op_name + "_policy")

    @property
    def k_op_problem(self) -> str:
        return snake_to_pascal_case(self.op_name + "_problem")

    def gen_host_api(self):
        """Write ``<op>_api.cpp`` and ``<op>_api.hpp`` into ``base_dir``."""
        text_ = TILE_OP_HEADER + TILE_OP_HOST_API_CPP.format(
            op_host_api_file=self.op_host_api_file,
            op_host_api=self.op_host_api,
            op_traits=self.op_traits,
            op_kargs=self.op_kargs,
            op_problem=self.op_problem,
            op_pipeline=self.op_pipeline,
            op_kernel=self.op_kernel)
        (self.base_dir / (self.op_host_api_file + ".cpp")).write_text(text_)

        text_ = TILE_OP_HEADER + TILE_OP_HOST_API_HPP.format(
            op_name=self.op_name,
            op_traits=self.op_traits,
            op_kargs=self.op_kargs,
            k_op_host_args=self.k_op_host_args,
            op_host_api=self.op_host_api)
        (self.base_dir / (self.op_host_api_file + ".hpp")).write_text(text_)

    def gen_kernel(self):
        """Write the device headers under ``base_dir/include/ck_tile/ops``.

        Creates the ``<op>/kernel`` and ``<op>/pipeline`` directories when
        missing, then writes the kernel, pipeline, policy and problem headers
        plus the umbrella ``<op>.hpp`` header.
        """
        ops = self.base_dir / 'include' / 'ck_tile' / 'ops'
        ops_op_kernel = ops / self.op_name / 'kernel'
        ops_op_pipeline = ops / self.op_name / 'pipeline'
        ops_op_kernel.mkdir(parents=True, exist_ok=True)
        ops_op_pipeline.mkdir(parents=True, exist_ok=True)

        # kernel
        text_ = TILE_OP_HEADER + TILE_OP_KERNEL.format(
            k_op_host_args=self.k_op_host_args,
            k_op_kernel=self.k_op_kernel,
            k_op_kargs=self.k_op_kargs)
        (ops_op_kernel / (self.op_name + "_kernel.hpp")).write_text(text_)

        # pipeline
        text_ = TILE_OP_HEADER + TILE_OP_PIPELINE.format(
            op_name=self.op_name,
            k_op_policy=self.k_op_policy,
            k_op_pipeline=self.k_op_pipeline)
        (ops_op_pipeline / (self.op_name + "_pipeline.hpp")).write_text(text_)

        # policy
        text_ = TILE_OP_HEADER + TILE_OP_POLICY.format(k_op_policy=self.k_op_policy)
        (ops_op_pipeline / (self.op_name + "_policy.hpp")).write_text(text_)

        # problem
        text_ = TILE_OP_HEADER + TILE_OP_PROBLEM.format(k_op_problem=self.k_op_problem)
        (ops_op_pipeline / (self.op_name + "_problem.hpp")).write_text(text_)

        # one for all header
        text_ = TILE_OP_HEADER + TILE_OP_DEVICE_HEADER.format(op_name=self.op_name)
        (ops / (self.op_name + ".hpp")).write_text(text_)

    def gen(self):
        """Generate the host API files, then the device headers."""
        self.gen_host_api()
        self.gen_kernel()


def gen_tile_op(args):
    """Generate an op skeleton from parsed command-line arguments.

    Reads ``args.op_name`` (lower-cased before use) and ``args.directory``;
    output goes to ``<directory>/<op_name>``, which is created when missing.
    """
    name = args.op_name.lower()
    base_dir = Path(args.directory) / name
    base_dir.mkdir(parents=True, exist_ok=True)
    op = tile_op_template(base_dir, name)
    op.gen()


def main():
    """Parse the command line and run the generator."""
    parser = argparse.ArgumentParser(
        prog="gen_tile_op",
        description="generate ck_tile op template (you still need to write kernel :))",
    )
    parser.add_argument(
        "-d",
        "--directory",
        default='./',
        required=False,
        help="where to generate the op, default is current directory"
    )
    parser.add_argument(
        "-p",
        "--op_name",
        default='foo',
        required=False,
        help="operator name to generate"
    )
    args = parser.parse_args()
    gen_tile_op(args)


if __name__ == "__main__":
    main()