Commit b76ef72c authored by carlushuang's avatar carlushuang
Browse files

add gen tile template script

parent e20ed766
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
TILE_OP_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by gen_tile_op.py
"""
TILE_OP_HOST_API_CPP="""
#include "{op_host_api_file}.hpp"
float {op_host_api}({op_traits} t, {op_kargs} a, ck_tile::stream_config s)
{{
// TODO: write some dispatch code
(void)t;
{{
using problem = ck_tile::{op_problem};
using pipeline = ck_tile::{op_pipeline}<problem>;
using kernel = ck_tile::{op_kernel}<pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, 1>(kernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}}
return -1; // not supported by this API
}}
"""
TILE_OP_HOST_API_HPP="""
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/{op_name}.hpp"
struct {op_traits}
{{
// TODO: add more trait for selecting kernel
}};
struct {op_kargs} : public ck_tile::{k_op_host_args}
{{
}};
float {op_host_api}({op_traits} t, {op_kargs} a, ck_tile::stream_config s);
"""
TILE_OP_KERNEL="""
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {{
struct {k_op_host_args}
{{
// TODO: add host args
}};
template <typename Pipeline_>
struct {k_op_kernel}
{{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
struct {k_op_kargs}
{{
// TODO: add kernel args
}};
using Kargs = {k_op_kargs};
using Hargs = {k_op_host_args};
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{{
// TODO: return how many grids
(void)h;
return dim3(1);
}}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{{
(void)h;
Kargs k;
return k;
}}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() {{ return Problem::BlockSize; }}
CK_TILE_DEVICE void
operator()(Kargs kargs) const
{{
// entry point of this kernel
(void)kargs;
// Pipeline{{}}(input_window, output_window, loop_stride);
}}
}};
}} // namespace ck_tile
"""
TILE_OP_PIPELINE="""
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {{
template <typename Problem_, typename Policy_ = {k_op_policy}>
struct {k_op_pipeline}
{{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
CK_TILE_DEVICE auto
operator()()
{{
// pipeline is here
}}
}};
}} // namespace ck_tile
"""
TILE_OP_POLICY="""
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {{
struct {k_op_policy}
{{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{{
// TODO: create some discriptor
// return make_static_tile_distribution(
// tile_distribution_encoding<sequence<1>,
// tuple<sequence<Problem::IssuesPerRow,
// Problem::WarpsPerBlock,
// Problem::LanesPerRow,
// Problem::VectorSize>>,
// tuple<sequence<1>, sequence<1>>,
// tuple<sequence<1>, sequence<2>>,
// sequence<1, 1>,
// sequence<0, 3>>{{}});
}}
}};
}} // namespace ck_tile
"""
TILE_OP_PROBLEM="""
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {{
template <index_t BlockSize_ = 256>
struct {k_op_problem}
{{
static constexpr index_t BlockSize = BlockSize_;
static constexpr index_t WarpSize = get_warp_size();
}};
}} // namespace ck_tile
"""
TILE_OP_DEVICE_HEADER="""
#pragma once
#include "ck_tile/ops/{op_name}/kernel/{op_name}_kernel.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_pipeline.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_policy.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_problem.hpp"
"""
def snake_to_pascal_case(snake_case_str):
words = snake_case_str.split('_')
#pascal_case_str = words[0].lower() + ''.join(word.title() for word in words[1:])
pascal_case_str = ''.join(word.title() for word in words)
return pascal_case_str
class tile_op_template:
def __init__(self, base_dir : Path, op_name : str):
self.base_dir = base_dir
self.op_name = op_name
@property
def op_host_api_file(self) -> str:
return self.op_name + "_api"
@property
def op_host_api(self) -> str:
return self.op_name
@property
def op_problem(self) -> str:
return self.op_name + "_problem"
@property
def op_pipeline(self) -> str:
return self.op_name + "_pipeline"
@property
def op_policy(self) -> str:
return self.op_name + "_policy"
@property
def op_kernel(self) -> str:
return self.op_name + "_kernel"
@property
def op_traits(self) -> str:
return self.op_name + "_traits"
@property
def op_kargs(self) -> str:
return self.op_name + "_kargs"
@property
def k_op_host_args(self) -> str:
return snake_to_pascal_case(self.op_name + "_host_args")
@property
def k_op_kernel(self) -> str:
return snake_to_pascal_case(self.op_name + "_kenrel")
@property
def k_op_kargs(self) -> str:
return snake_to_pascal_case(self.op_name + "_kargs")
@property
def k_op_pipeline(self) -> str:
return snake_to_pascal_case(self.op_name + "_pipeline")
@property
def k_op_policy(self) -> str:
return snake_to_pascal_case(self.op_name + "_policy")
@property
def k_op_problem(self) -> str:
return snake_to_pascal_case(self.op_name + "_problem")
def gen_host_api(self):
text_ = TILE_OP_HEADER + TILE_OP_HOST_API_CPP.format(op_host_api_file = self.op_host_api_file,
op_host_api=self.op_host_api, op_traits=self.op_traits, op_kargs=self.op_kargs,
op_problem=self.op_problem, op_pipeline=self.op_pipeline, op_kernel=self.op_kernel)
(self.base_dir / (self.op_host_api_file + ".cpp")).write_text(text_)
text_ = TILE_OP_HEADER + TILE_OP_HOST_API_HPP.format(op_name=self.op_name, op_traits=self.op_traits, op_kargs=self.op_kargs,
k_op_host_args=self.k_op_host_args, op_host_api=self.op_host_api)
(self.base_dir / (self.op_host_api_file + ".hpp")).write_text(text_)
def gen_kernel(self):
ops = self.base_dir / 'include' / 'ck_tile' / 'ops'
ops_op_kernel = ops / self.op_name / 'kernel'
ops_op_pipeline = ops / self.op_name / 'pipeline'
ops_op_kernel.mkdir(parents=True, exist_ok=True)
ops_op_pipeline.mkdir(parents=True, exist_ok=True)
# kernel
text_ = TILE_OP_HEADER + TILE_OP_KERNEL.format(k_op_host_args=self.k_op_host_args,
k_op_kernel=self.k_op_kernel, k_op_kargs=self.k_op_kargs)
(ops_op_kernel / (self.op_name + "_kernel.hpp")).write_text(text_)
# pipeline
text_ = TILE_OP_HEADER + TILE_OP_PIPELINE.format(op_name=self.op_name,
k_op_policy=self.k_op_policy, k_op_pipeline=self.k_op_pipeline)
(ops_op_pipeline / (self.op_name + "_pipeline.hpp")).write_text(text_)
# policy
text_ = TILE_OP_HEADER + TILE_OP_POLICY.format(k_op_policy=self.k_op_policy)
(ops_op_pipeline / (self.op_name + "_policy.hpp")).write_text(text_)
# problem
text_ = TILE_OP_HEADER + TILE_OP_PROBLEM.format(k_op_problem=self.k_op_problem)
(ops_op_pipeline / (self.op_name + "_problem.hpp")).write_text(text_)
# one for all header
text_ = TILE_OP_HEADER + TILE_OP_DEVICE_HEADER.format(op_name = self.op_name)
(ops / (self.op_name + ".hpp")).write_text(text_)
def gen(self):
self.gen_host_api()
self.gen_kernel()
def gen_tile_op(args):
name = args.op_name.lower()
base_dir = Path(args.directory) / name
base_dir.mkdir(parents=True, exist_ok=True)
op = tile_op_template(base_dir, name)
op.gen()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="gen_tile_op",
description="generate ck_tile op template (you still need to write kernel :))",
)
parser.add_argument(
"-d",
"--directory",
default='./',
required=False,
help="where to generate the op, default is current directory"
)
parser.add_argument(
"-p",
"--op_name",
default='foo',
required=False,
help="operator name to generate"
)
args = parser.parse_args()
gen_tile_op(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment