"vscode:/vscode.git/clone" did not exist on "8fbd84bf7839d53e6dd26a1dd4473dd1a99aab6e"
gemm_op.py 4.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Take GEMM parameters from the user and pass them to the example template.
import enum
import ck_types
from copy import deepcopy
from dataclasses import dataclass
from enum import auto
from typing import List
from ck_types import *

class GemmType:
    """Named GEMM specializations, spelled as fully qualified C++ enumerators."""

    GemmDefault: str = "ck::tensor_operation::device::GemmSpecialization::Default"

# class GemmSpecialization(enum.Enum):
#     GemmDefault = auto()
#     MNKPadding = auto()
#     MNPadding = auto()
#     MNOPadding = auto()
#     MNKOPadding = auto()


# GemmSpecializationTag = {
#     GemmSpecialization.GemmDefault: "ck::tensor_operation::device::GemmSpecialization::Default",
#     GemmSpecialization.MNKPadding: "ck::tensor_operation::device::GemmSpecialization::MNKPadding",
#     GemmSpecialization.MNPadding: "ck::tensor_operation::device::GemmSpecialization::MNPadding",
#     GemmSpecialization.MNOPadding: "ck::tensor_operation::device::GemmSpecialization::MNOPadding",
#     GemmSpecialization.MNKOPadding: "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
# }

@dataclass
class TileDesc:
    """Tiling parameters of a GEMM kernel instance.

    ``str()`` joins every field value, in declaration order, with ``_`` so the
    result can be embedded in an instance name.
    """

    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    k1: int
    m_per_thread: int
    n_per_thread: int
    k_per_thread: int
    # Annotated as str, but any value with a useful str() is accepted
    # (e.g. ints or int lists).
    m1n1_thcluster_m1xs: str
    m1n1_thcluster_n1xs: str

    def __str__(self) -> str:
        """Return all field values joined with ``_`` in declaration order."""
        # A dataclass's instance __dict__ preserves field declaration order.
        return "_".join(str(value) for value in self.__dict__.values())

@dataclass
class BlockTransferDesc:
    """Block-transfer tuning parameters for one GEMM input operand.

    Fields are annotated ``str`` but may also hold sequences of ints; both are
    handled by ``__str__``.
    """

    thread_slice_length: str
    thread_cluster_length: str
    thread_cluster_arrange_order: str
    src_access_order: str
    src_vec_tensor_lengths: str
    src_vec_tensor_cont_dim_order: str
    dst_vec_tensor_lengths: str

    def __str__(self) -> str:
        """Return a compact, identifier-friendly name for these parameters.

        Fields are joined with ``_`` in declaration order. Sequence-valued
        fields have their elements joined with ``x`` (``[8, 1, 32, 1]`` ->
        ``8x1x32x1``); any other value is rendered with ``str()``.
        """
        # Previously this method built a dict and fell off the end, returning
        # None, which made str() raise TypeError.
        parts = []
        for value in self.__dict__.values():
            if isinstance(value, (list, tuple)):
                parts.append("x".join(str(item) for item in value))
            else:
                parts.append(str(value))
        return "_".join(parts)

@dataclass
class CBlockTransferDesc:
    """Transfer parameters for writing the GEMM output (C) tile.

    ``src_dst_access_order`` is annotated ``str`` but may also be a sequence
    of ints; ``__str__`` handles both.
    """

    src_dst_access_order: str
    src_dst_vec_dim: int
    dst_scalar_per_vector: int

    def __str__(self) -> str:
        """Return a compact, identifier-friendly name for these parameters.

        Fields are joined with ``_`` in declaration order. Sequence-valued
        fields have their elements joined with ``x``; any other value is
        rendered with ``str()``.
        """
        # Previously this method returned None, which made str() raise TypeError.
        parts = []
        for value in self.__dict__.values():
            if isinstance(value, (list, tuple)):
                parts.append("x".join(str(item) for item in value))
            else:
                parts.append(str(value))
        return "_".join(parts)


@dataclass
class GemmOperation:
    """Complete description of one GEMM kernel instance.

    Bundles the A/B/C tensor descriptors, element-wise operators, the GEMM
    specialization and the tiling / block-transfer parameters. ``str()``
    produces a name for the instance derived from these fields.
    """

    # String (forward-reference) annotations: they are only inspected by
    # tools, so defining the class does not require resolving these names.
    A: "TensorDesc"
    B: "TensorDesc"
    C: "TensorDesc"
    a_elem_op: "TensorOperation"
    b_elem_op: "TensorOperation"
    epilogue_functor: "TensorOperation"
    # A GemmType constant (a C++ enumerator string); enum members also work.
    gemm_specialization: "GemmType"
    tile_desc: "TileDesc"
    a_block_transfer: "BlockTransferDesc"
    b_block_transfer: "BlockTransferDesc"
    b1_block_transfer: "BlockTransferDesc" = None
    c_block_transfer: "CBlockTransferDesc" = None

    @staticmethod
    def _token(value) -> str:
        """Return a short name fragment for an enum member or plain value."""
        if isinstance(value, enum.Enum):
            return value.name
        # Fully qualified C++ spellings such as
        # "ck::tensor_operation::device::GemmSpecialization::Default" are
        # reduced to their last component ("Default").
        return str(value).rsplit("::", 1)[-1]

    def __str__(self) -> str:
        """Return the instance name.

        Format: ``gemm_<spec>_<A dtype><B dtype><C dtype>_<A layout><B layout>
        <C layout>_<tile>[_C<n>]_<epilogue>``, where ``_C<n>`` is appended only
        when ``c_block_transfer.dst_scalar_per_vector`` is 1 or 4.
        """
        token = self._token
        # NOTE(review): this class stores no operation kind (the original
        # referenced a missing ``operation_kind`` lookup), so "gemm" is used
        # as a fixed prefix.
        io_name = (
            f"gemm_{token(self.gemm_specialization)}_"
            f"{token(self.A.element)}{token(self.B.element)}{token(self.C.element)}_"
            f"{token(self.A.layout)}{token(self.B.layout)}{token(self.C.layout)}"
        )

        extra_tile = ""
        if self.c_block_transfer is not None:
            scalar = self.c_block_transfer.dst_scalar_per_vector
            # Only these two output vector widths get a name suffix.
            if scalar in (1, 4):
                extra_tile = f"_C{scalar}"

        tile_name = str(self.tile_desc) + extra_tile
        return f"{io_name}_{tile_name}_{token(self.epilogue_functor)}"

    def accumulator_type(self):
        """Return the accumulator data type; currently always ``DataType.f16``."""
        # TODO(review): the original author questioned this ("f.32?"); fp16
        # GEMMs often accumulate in fp32 -- confirm before relying on it.
        return DataType.f16

if __name__ == "__main__":
    # Demo: build one fp16 GEMM description (A row-major, B column-major,
    # C row-major) and print its A element-wise operator.
    # The previous demo passed 8 / 4 positional arguments to dataclasses with
    # 7 / 3 fields and crashed with TypeError; keyword arguments now tie each
    # value to its field.
    # NOTE(review): tuning values are assumed to mirror CK's DeviceGemmDl fp16
    # example instance -- confirm against the target template.
    A = TensorDesc(DataType.f16, Layout.RowMajor)
    B = TensorDesc(DataType.f16, Layout.ColumnMajor)
    C = TensorDesc(DataType.f16, Layout.RowMajor)

    # A and B use identical transfer parameters; B gets its own copy below so
    # the two descriptors never share mutable lists.
    ab_transfer = dict(
        thread_slice_length=[2, 1, 4, 2],
        thread_cluster_length=[8, 1, 32, 1],
        thread_cluster_arrange_order=[0, 3, 1, 2],
        src_access_order=[0, 3, 1, 2],
        src_vec_tensor_lengths=[1, 1, 4, 1],
        src_vec_tensor_cont_dim_order=[0, 3, 1, 2],
        dst_vec_tensor_lengths=[1, 1, 4, 2],
    )

    GemmOp = GemmOperation(
        A=A,
        B=B,
        C=C,
        a_elem_op=TensorOperation.PassThrough,
        b_elem_op=TensorOperation.PassThrough,
        epilogue_functor=TensorOperation.PassThrough,
        gemm_specialization=GemmType.GemmDefault,
        tile_desc=TileDesc(
            block_size=256,
            m_per_block=128,
            n_per_block=128,
            k_per_block=16,
            k1=2,
            m_per_thread=4,
            n_per_thread=4,
            k_per_thread=1,
            m1n1_thcluster_m1xs=[8, 2],
            m1n1_thcluster_n1xs=[8, 2],
        ),
        a_block_transfer=BlockTransferDesc(**ab_transfer),
        b_block_transfer=BlockTransferDesc(**deepcopy(ab_transfer)),
        c_block_transfer=CBlockTransferDesc(
            src_dst_access_order=[0, 1, 2, 3, 4, 5],
            src_dst_vec_dim=5,
            dst_scalar_per_vector=4,
        ),
    )
    print(GemmOp.a_elem_op)