"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "b07243b4ef8e4c024a78b76df838437df8d184bc"
ir.py 1.63 KB
Newer Older
1
2
3
from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
4
import tvm_ffi
5
6
from tvm.target import Target
from tilelang import _ffi_api
7
8


9
@tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.Fill"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


14
@tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.AtomicAdd"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


19
@tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.Copy"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


24
@tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.Conv2DIm2Col"``.

    Note that the registration key has no ``Op`` suffix even though the class
    name does. No Python members are declared here beyond what ``Node`` and
    ``Scriptable`` provide (TODO confirm native attributes).
    """


29
@tvm_ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.GemmWarpPolicy"``.

    The annotated fields below declare the attributes this class reads; their
    values presumably live on the native object (TODO confirm).
    """

    # Policy selector; the meaning of each integer value is defined outside
    # this file and is not visible here.
    policy_type: int
    # Warp count along M, as currently stored on this object.
    m_warp: int
    # Warp count along N, as currently stored on this object.
    n_warp: int

    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
                               is_wgmma: bool):
        """Run the native warp-partition computation and return the warp counts.

        Args:
            M: Extent along M; coerced with ``int()`` before the FFI call.
            N: Extent along N; coerced with ``int()`` before the FFI call.
            block_size: Block size; coerced with ``int()`` before the FFI call.
                Presumably the number of threads per block (TODO confirm).
            target: Compilation target, forwarded unchanged to the FFI call.
            is_wgmma: Flag forwarded unchanged to the FFI call.

        Returns:
            The tuple ``(self.m_warp, self.n_warp)``, read after the FFI call.
        """
        # The FFI call's return value is discarded; reading the fields
        # afterwards relies on the native function updating them on ``self``.
        # NOTE(review): assumed in-place update -- confirm on the C++ side.
        _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
                                                    is_wgmma)
        return self.m_warp, self.n_warp
40
41


42
@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.Gemm"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


47
@tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.GemmSP"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


52
@tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.FinalizeReducerOp"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


57
@tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.ParallelOp"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


62
@tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.ReduceOp"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


67
@tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.CumSumOp"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


72
@tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.RegionOp"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """


77
@tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
    """Python-side handle for the FFI object type registered as ``"tl.ReduceType"``.

    No Python members are declared here beyond what ``Node`` and ``Scriptable``
    provide; any further attributes presumably come from the native
    registration (TODO confirm against the C++ definition).
    """