gen_tile_op.py 9.36 KB
Newer Older
carlushuang's avatar
carlushuang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import argparse
from enum import IntEnum
from pathlib import Path

TILE_OP_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by gen_tile_op.py

"""

TILE_OP_HOST_API_CPP="""
#include "{op_host_api_file}.hpp"

float {op_host_api}({op_traits} t, {op_kargs} a, ck_tile::stream_config s)
{{
    // TODO: write some dispatch code
    (void)t;
    {{
        using problem = ck_tile::{op_problem};
        using pipeline = ck_tile::{op_pipeline}<problem>;
        using kernel   = ck_tile::{op_kernel}<pipeline>;

        auto kargs            = kernel::MakeKargs(a);
        const dim3 grids      = kernel::GridSize(a);
        constexpr dim3 blocks = kernel::BlockSize();

        float ave_time = ck_tile::launch_kernel(s,
            ck_tile::make_kernel<blocks.x, 1>(kernel{{}}, grids, blocks, 0, kargs));
        return ave_time;
    }}
    return -1;  // not supported by this API
}}

"""

TILE_OP_HOST_API_HPP="""
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/{op_name}.hpp"

struct {op_traits}
{{
    // TODO: add more trait for selecting kernel
}};

struct {op_kargs} : public ck_tile::{k_op_host_args}
{{
}};

float {op_host_api}({op_traits} t, {op_kargs} a, ck_tile::stream_config s);
"""

TILE_OP_KERNEL="""
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {{

struct {k_op_host_args}
{{
    // TODO: add host args
}};

template <typename Pipeline_>
struct {k_op_kernel}
{{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = remove_cvref_t<typename Pipeline::Problem>;

    struct {k_op_kargs}
    {{
        // TODO: add kernel args
    }};

    using Kargs = {k_op_kargs};
    using Hargs = {k_op_host_args};

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {{
        // TODO: return how many grids
        (void)h;
        return dim3(1);
    }}

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {{
        (void)h;
        Kargs k;
        return k;
    }}

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() {{ return Problem::BlockSize; }}

    CK_TILE_DEVICE void
    operator()(Kargs kargs) const
    {{
        // entry point of this kernel
        (void)kargs;
        // Pipeline{{}}(input_window, output_window, loop_stride);
    }}
}};
}} // namespace ck_tile

"""

TILE_OP_PIPELINE="""
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {{

template <typename Problem_, typename Policy_ = {k_op_policy}>
struct {k_op_pipeline}
{{
    // TODO: this kernel only support warp per row
    using Problem      = remove_cvref_t<Problem_>;
    using Policy       = remove_cvref_t<Policy_>;

    CK_TILE_DEVICE auto
    operator()()
    {{
        // pipeline is here
    }}
}};
}} // namespace ck_tile

"""

TILE_OP_POLICY="""
#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {{

struct {k_op_policy}
{{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {{
        // TODO: create some discriptor
        // return make_static_tile_distribution(
        //     tile_distribution_encoding<sequence<1>,
        //                                tuple<sequence<Problem::IssuesPerRow,
        //                                               Problem::WarpsPerBlock,
        //                                               Problem::LanesPerRow,
        //                                               Problem::VectorSize>>,
        //                                tuple<sequence<1>, sequence<1>>,
        //                                tuple<sequence<1>, sequence<2>>,
        //                                sequence<1, 1>,
        //                                sequence<0, 3>>{{}});
    }}
}};
}} // namespace ck_tile

"""

TILE_OP_PROBLEM="""
#pragma once

#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {{

template <index_t BlockSize_     = 256>
struct {k_op_problem}
{{
    static constexpr index_t BlockSize     = BlockSize_;
    static constexpr index_t WarpSize      = get_warp_size();
}};
}} // namespace ck_tile

"""

TILE_OP_DEVICE_HEADER="""
#pragma once

#include "ck_tile/ops/{op_name}/kernel/{op_name}_kernel.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_pipeline.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_policy.hpp"
#include "ck_tile/ops/{op_name}/pipeline/{op_name}_problem.hpp"

"""

def snake_to_pascal_case(snake_case_str):
    """Return *snake_case_str* converted to PascalCase.

    Each underscore-separated piece is passed through ``str.title`` and the
    pieces are glued together, e.g. ``"foo_host_args"`` -> ``"FooHostArgs"``.

    Because ``str.title`` is used, a letter following a digit is also
    upper-cased (``"conv2d"`` -> ``"Conv2D"``). The remaining letters of each
    piece are lower-cased, and empty pieces from repeated underscores vanish.
    """
    return "".join(map(str.title, snake_case_str.split("_")))

class tile_op_template:
    """Derive the C++ names for one ck_tile op and write its scaffold files.

    Two naming families are exposed:

    * ``op_*`` properties are snake_case names used by the host-side API
      files ``<op>_api.cpp`` / ``<op>_api.hpp`` (e.g. ``foo_traits``).
    * ``k_op_*`` properties are the PascalCase names of the types emitted
      into the ``ck_tile`` namespace (e.g. ``FooKernel``).

    Args:
        base_dir: directory that receives the generated files. The host API
            files go directly into it; the ck_tile headers go under
            ``base_dir/include/ck_tile/ops``.
        op_name: operator name used as the prefix of every generated name.
    """

    def __init__(self, base_dir : Path, op_name : str):
        # Coerce so a plain string path works as well as a Path.
        self.base_dir = Path(base_dir)
        self.op_name = op_name

    # ---- snake_case names used by the host-side API files ----

    @property
    def op_host_api_file(self) -> str:
        return self.op_name + "_api"

    @property
    def op_host_api(self) -> str:
        return self.op_name

    # op_problem / op_pipeline / op_policy / op_kernel are kept for API
    # compatibility; the generated C++ refers to these types through the
    # PascalCase k_op_* names below.
    @property
    def op_problem(self) -> str:
        return self.op_name + "_problem"

    @property
    def op_pipeline(self) -> str:
        return self.op_name + "_pipeline"

    @property
    def op_policy(self) -> str:
        return self.op_name + "_policy"

    @property
    def op_kernel(self) -> str:
        return self.op_name + "_kernel"

    @property
    def op_traits(self) -> str:
        return self.op_name + "_traits"

    @property
    def op_kargs(self) -> str:
        return self.op_name + "_kargs"

    # ---- PascalCase names of the types emitted into namespace ck_tile ----

    @property
    def k_op_host_args(self) -> str:
        return snake_to_pascal_case(self.op_name + "_host_args")

    @property
    def k_op_kernel(self) -> str:
        return snake_to_pascal_case(self.op_name + "_kernel")

    @property
    def k_op_kargs(self) -> str:
        return snake_to_pascal_case(self.op_name + "_kargs")

    @property
    def k_op_pipeline(self) -> str:
        return snake_to_pascal_case(self.op_name + "_pipeline")

    @property
    def k_op_policy(self) -> str:
        return snake_to_pascal_case(self.op_name + "_policy")

    @property
    def k_op_problem(self) -> str:
        return snake_to_pascal_case(self.op_name + "_problem")

    @staticmethod
    def _emit(path: Path, template: str, **fields: str) -> None:
        """Write the license banner plus ``template.format(**fields)`` to *path*."""
        path.write_text(TILE_OP_HEADER + template.format(**fields), encoding="utf-8")

    def gen_host_api(self) -> None:
        """Write ``<op>_api.cpp`` and ``<op>_api.hpp`` into ``base_dir``."""
        # The .cpp refers to ck_tile::<problem/pipeline/kernel> types, so it
        # must receive the PascalCase k_op_* names that gen_kernel() emits.
        self._emit(
            self.base_dir / (self.op_host_api_file + ".cpp"),
            TILE_OP_HOST_API_CPP,
            op_host_api_file=self.op_host_api_file,
            op_host_api=self.op_host_api,
            op_traits=self.op_traits,
            op_kargs=self.op_kargs,
            op_problem=self.k_op_problem,
            op_pipeline=self.k_op_pipeline,
            op_kernel=self.k_op_kernel,
        )
        self._emit(
            self.base_dir / (self.op_host_api_file + ".hpp"),
            TILE_OP_HOST_API_HPP,
            op_name=self.op_name,
            op_traits=self.op_traits,
            op_kargs=self.op_kargs,
            k_op_host_args=self.k_op_host_args,
            op_host_api=self.op_host_api,
        )

    def gen_kernel(self) -> None:
        """Write the ck_tile headers under ``base_dir/include/ck_tile/ops``.

        Produces ``<op>/kernel/<op>_kernel.hpp``, the pipeline, policy and
        problem headers under ``<op>/pipeline/``, and the umbrella header
        ``<op>.hpp``. Directories are created as needed.
        """
        ops = self.base_dir / 'include' / 'ck_tile' / 'ops'
        kernel_dir = ops / self.op_name / 'kernel'
        pipeline_dir = ops / self.op_name / 'pipeline'

        kernel_dir.mkdir(parents=True, exist_ok=True)
        pipeline_dir.mkdir(parents=True, exist_ok=True)

        self._emit(
            kernel_dir / (self.op_name + "_kernel.hpp"),
            TILE_OP_KERNEL,
            k_op_host_args=self.k_op_host_args,
            k_op_kernel=self.k_op_kernel,
            k_op_kargs=self.k_op_kargs,
        )
        self._emit(
            pipeline_dir / (self.op_name + "_pipeline.hpp"),
            TILE_OP_PIPELINE,
            op_name=self.op_name,
            k_op_policy=self.k_op_policy,
            k_op_pipeline=self.k_op_pipeline,
        )
        self._emit(
            pipeline_dir / (self.op_name + "_policy.hpp"),
            TILE_OP_POLICY,
            k_op_policy=self.k_op_policy,
        )
        self._emit(
            pipeline_dir / (self.op_name + "_problem.hpp"),
            TILE_OP_PROBLEM,
            k_op_problem=self.k_op_problem,
        )
        # one-for-all header that includes the four headers above
        self._emit(
            ops / (self.op_name + ".hpp"),
            TILE_OP_DEVICE_HEADER,
            op_name=self.op_name,
        )

    def gen(self) -> None:
        """Generate both the host API files and the ck_tile headers."""
        self.gen_host_api()
        self.gen_kernel()

def gen_tile_op(args):
    """Generate the scaffold for one ck_tile operator.

    Args:
        args: parsed CLI namespace providing ``op_name`` and ``directory``.
            The op name is lower-cased, and output goes to
            ``directory/<op_name>``.

    Raises:
        ValueError: if the lower-cased op name is not a plain ASCII
            identifier (letters, digits, underscores, not starting with a
            digit). The name becomes the prefix of every generated C++
            identifier and the bare name of the host API function, so any
            other name would produce C++ that cannot compile.
    """
    import re

    name = args.op_name.lower()
    # Validate before touching the filesystem so a bad name leaves no
    # partially-created directory behind.
    if not re.fullmatch(r"[a-z_][a-z0-9_]*", name):
        raise ValueError(
            "invalid op name {!r}: expected an identifier such as 'foo' or "
            "'my_op' (letters, digits and underscores, not starting with a "
            "digit)".format(args.op_name))

    base_dir = Path(args.directory) / name
    base_dir.mkdir(parents=True, exist_ok=True)

    op = tile_op_template(base_dir, name)
    op.gen()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="gen_tile_op",
        description="generate ck_tile op template (you still need to write kernel :))",
    )
    parser.add_argument(
        "-d",
        "--directory",
        default='./',
        required=False,
        help="where to generate the op, default is current directory"
    )
    parser.add_argument(
        "-p",
        "--op_name",
        default='foo',
        required=False,
        help="operator name to generate"
    )

    args = parser.parse_args()

    gen_tile_op(args)