from datetime import datetime
import pathlib
from pathlib import Path
import subprocess
import os
import copy

# Prefix prepended to every generated #include path.
NS = 'ck_tile'
# Second-level directory names under DEVICE/HOST that receive special handling.
OPS = 'ops'
REF = 'ref'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
# Top-level directory names whose contents get aggregate headers generated.
DEVICE = 'device'
HOST = 'host'
UTIL = 'util'

# Banner written at the top of every generated header; the copyright range
# ends at the year in which the script is run.
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
"""

# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
def get_module(f, level = 0):
    """Return the path component of ``f`` at index ``level`` as a string.

    ``f`` must expose a ``parts`` sequence (as pathlib paths do); ``level``
    indexes into it, so 0 is the first component.
    """
    return str(f.parts[level])

# Every .hpp file found recursively below the current directory, in sorted
# path order, stored as pure (filesystem-agnostic) paths.
all_files = [
    pathlib.PurePath(entry)
    for entry in sorted(Path("./").rglob("*"))
    if entry.suffix == '.hpp'
]

class submodule_t:
    """Group header paths by module and write one aggregate header per group.

    ``self.m`` maps the first path component to:

    * for DEVICE and HOST: a dict keyed by the second path component, whose
      value is a list of paths, or -- for OPS and UTIL -- a further dict
      keyed by the third path component holding lists of paths;
    * for any other first component: a flat list of paths.
    """
    def __init__(self):
        self.m = dict()

    def push(self, f):
        """Record the path ``f`` under the module(s) named by its leading components.

        Paths with exactly two parents (for a relative path: ``<dir>/<file>``,
        e.g. ``device/xxx.hpp`` and ``host/xxx.hpp``) are ignored, as are
        files sitting directly in an OPS or UTIL directory and everything
        under REF.
        """
        if len(f.parents) == 2: # ignore device/xxx.hpp and host/xxx.hpp
            return

        mod = get_module(f)
        if mod != HOST and mod != DEVICE:
            self.m.setdefault(mod, list()).append(f)
            return

        group = self.m.setdefault(mod, dict())
        mod2 = get_module(f, 1)

        # ref is supposed to include one header on demand TODO: ref moved to host/reference
        if mod2 == REF:
            return

        if mod2 == UTIL or mod2 == OPS:
            submodules = group.setdefault(mod2, dict())
            mod3 = get_module(f, 2)
            # ignore util/xxx.hpp and ops/xxx.hpp
            if Path(mod3).suffix != '.hpp':
                submodules.setdefault(mod3, list()).append(f)
            return

        group.setdefault(mod2, list()).append(f)

    def gen(self):
        """Write the aggregate headers for everything recorded under DEVICE and HOST.

        A list-valued group ``<top>/<name>`` produces ``<top>/<name>.hpp``.
        A dict-valued group (OPS, UTIL) produces one header per submodule at
        ``<top>/<name>/<submodule>.hpp``; for OPS, the headers of the
        OPS_COMMON submodule are additionally appended to every other
        submodule's include list.

        ``self.m`` is not modified, so calling ``gen`` repeatedly writes the
        same files each time.
        """
        def gen_header(hpath, include_list):
            # Delete any previous file so a fresh one is created.
            if os.path.exists(str(hpath)):
                os.remove(str(hpath))
            with hpath.open('w') as f:
                f.write(HEADER_COMMON)
                f.write('#pragma once\n')
                f.write('\n')
                for individual_header in include_list:
                    header_path = NS + '/' + str(individual_header)
                    f.write(f'#include "{header_path}"\n')
                # no trailing blank line, otherwise clang-format will complain

        for k0, groups in self.m.items():
            if k0 != DEVICE and k0 != HOST:
                continue
            for k, v in groups.items():
                if not isinstance(v, dict):
                    gen_header(Path(f'{k0}/{k}.hpp'), v)
                    continue

                # Nested group: one header per submodule. Passing the dict
                # itself to gen_header would emit its keys as include paths.
                # NOTE(review): UTIL is assumed to want the same per-submodule
                # layout as OPS -- confirm.
                common_list = list(v.get(OPS_COMMON, [])) if k == OPS else []
                for km, kv in v.items():
                    # common headers are duplicated into every other ops/* module
                    extra = common_list if km != OPS_COMMON else []
                    gen_header(Path(k0 + '/' + k) / (f'{km}.hpp'), kv + extra)

def _run_tool(argv):
    """Run an external formatting tool to completion, treating failures as non-fatal.

    ``argv`` is the command as a list, so no shell is involved and paths with
    spaces or shell metacharacters are passed through unchanged. A missing
    executable or a non-zero exit status does not abort the script.
    """
    import sys

    try:
        subprocess.run(argv, check=False)
    except OSError as err:
        print(f'warning: could not run {argv[0]}: {err}', file=sys.stderr)


submodule = submodule_t()
# formatting
for x in all_files:
    # Each tool must finish before the next one starts: both rewrite the
    # file in place, so running them concurrently would race on it.
    _run_tool(['dos2unix', str(x)])
    _run_tool(['clang-format-12', '-style=file', '-i', str(x)])
    submodule.push(x)

submodule.gen()

#print(all_files)