user.py 7.15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gemm_op as gemm
import enum
from dataclasses import dataclass
from enum import auto
import ck_types
from ck_types import *

def CreateGemmOperator():
    """Build the list of f16 GEMM operations to instantiate.

    Every operation produced here uses:
      * A: ``DataType.f16``, ``Layout.ColumnMajor``
      * B: ``DataType.f16``, ``Layout.RowMajor``
      * C: ``DataType.f16``, ``Layout.RowMajor``
      * ``TensorOperation.PassThrough`` for the A/B element ops and the epilogue.

    The four descriptor tables below (tile, A/B block transfer, C thread
    transfer) are parallel: entry ``i`` of each table together describes one
    kernel configuration, so the tables must stay the same length and in the
    same order.  One operation is emitted per (GEMM specialization,
    configuration) pair.

    Returns:
        list: ``gemm.GemmOperation`` objects, grouped by specialization and,
        within a specialization, in table order.

    Raises:
        ValueError: if the descriptor tables differ in length.  ``zip`` would
            otherwise silently drop the unmatched trailing entries.
    """
    a_element_desc = TensorDesc(DataType.f16, Layout.ColumnMajor)
    b_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    c_element_desc = TensorDesc(DataType.f16, Layout.RowMajor)
    element_op = TensorOperation.PassThrough

    # Positional fields are interpreted by gemm.TileDesc (gemm_op module);
    # presumably they follow the CK DL-GEMM template parameter order -- confirm there.
    tile_descriptions = [
        gemm.TileDesc(256, 128, 128, 16, 2, 4, 4, 1, "S<8, 2>", "S<8, 2>"),
        gemm.TileDesc(256, 128, 128, 8, 2, 4, 4, 1, "S<8, 2>", "S<8, 2>"),
        gemm.TileDesc(128, 64, 128, 8, 2, 4, 4, 1, "S<4, 2>", "S<8, 2>"),
        gemm.TileDesc(128, 128, 64, 8, 2, 4, 4, 1, "S<8, 2>", "S<4, 2>"),
        gemm.TileDesc(256, 64, 128, 8, 2, 2, 4, 1, "S<8, 2>", "S<8, 2>"),
        gemm.TileDesc(256, 128, 64, 8, 2, 4, 2, 1, "S<8, 2>", "S<8, 2>"),
        gemm.TileDesc(128, 32, 128, 8, 2, 2, 4, 1, "S<4, 2>", "S<8, 2>"),
        gemm.TileDesc(128, 128, 32, 8, 2, 4, 2, 1, "S<8, 2>", "S<4, 2>"),
        gemm.TileDesc(128, 32, 64, 8, 2, 2, 2, 1, "S<4, 2>", "S<8, 2>"),
        gemm.TileDesc(128, 64, 32, 8, 2, 2, 2, 1, "S<8, 2>", "S<4, 2>"),
    ]

    a_block_descriptions = [
        gemm.BlockTransferDesc("S<2, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 8, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 2, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 2, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 8, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 2, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
    ]

    b_block_descriptions = [
        gemm.BlockTransferDesc("S<2, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 8, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 2, 2>", "S<8, 1, 32, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 8, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 2, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 4, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 4, 1>", "S<0, 3, 1, 2>", "S<1, 1, 4, 2>"),
        gemm.BlockTransferDesc("S<1, 1, 2, 2>", "S<8, 1, 16, 1>", "S<0, 3, 1, 2>", "S<0, 3, 1, 2>", "S<1, 1, 2, 1>", "S<0, 3, 1, 2>", "S<1, 1, 2, 2>"),
    ]

    c_block_descriptions = [
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 2),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 2),
        # NOTE(review): every other entry's last value equals the 7th TileDesc
        # field of the same index; this one pairs 4 with a tile whose 7th
        # field is 2 -- confirm this is intended.
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 4),
        gemm.CBlockTransferDesc("S<0, 1, 2, 3, 4, 5>", 5, 2),
    ]

    # TODO: derive the A/B/C transfer descriptors from the tile shape instead of
    # hand-maintaining parallel tables (an earlier AIT-based sketch of this is
    # in version control).

    # The tables are consumed in lock-step by zip(), which silently truncates
    # to the shortest input; fail loudly instead of dropping configurations.
    table_lengths = {
        "tile_descriptions": len(tile_descriptions),
        "a_block_descriptions": len(a_block_descriptions),
        "b_block_descriptions": len(b_block_descriptions),
        "c_block_descriptions": len(c_block_descriptions),
    }
    if len(set(table_lengths.values())) != 1:
        raise ValueError(
            f"GEMM descriptor tables must have equal lengths, got {table_lengths}"
        )

    gemm_specialization = [
        gemm.GemmType.GemmDefault,
    ]

    operations = []
    for gemm_spec in gemm_specialization:
        for tile_desc, a_block_desc, b_block_desc, c_block_desc in zip(
            tile_descriptions,
            a_block_descriptions,
            b_block_descriptions,
            c_block_descriptions,
        ):
            operations.append(
                gemm.GemmOperation(
                    A=a_element_desc,
                    B=b_element_desc,
                    C=c_element_desc,
                    a_elem_op=element_op,
                    b_elem_op=element_op,
                    epilogue_functor=element_op,
                    gemm_specialization=gemm_spec,
                    tile_desc=tile_desc,
                    a_block_transfer=a_block_desc,
                    b_block_transfer=b_block_desc,
                    c_block_transfer=c_block_desc,
                )
            )
    return operations