# Python-side classes for TileLang IR object types registered via ``tvm_ffi.register_object``.
from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm_ffi
from tvm.target import Target
from tilelang import _ffi_api


@tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
    """Python binding for the ``tl.Fill`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
    """Python binding for the ``tl.AtomicAdd`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
    """Python binding for the ``tl.Copy`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
    """Python binding for the ``tl.Conv2DIm2Col`` object type registered through ``tvm_ffi``.

    Note: the Python class name carries an ``Op`` suffix that the registered
    type key (``tl.Conv2DIm2Col``) does not.
    """


@tvm_ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
    """Python binding for the ``tl.GemmWarpPolicy`` object type registered through ``tvm_ffi``.

    Attributes:
        policy_type: Integer policy selector stored on the native object.
        m_warp: Warp partition value for the M dimension (presumably the number of
            warps along M); read back by ``compute_warp_partition``.
        n_warp: Warp partition value for the N dimension (presumably the number of
            warps along N); read back by ``compute_warp_partition``.
    """

    policy_type: int
    m_warp: int
    n_warp: int

    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
                               is_wgmma: bool):
        """Run the native warp-partition routine and return ``(m_warp, n_warp)``.

        ``M``, ``N`` and ``block_size`` are coerced to Python ``int`` before being
        forwarded to ``_ffi_api.GemmWarpPolicyComputeWarpPartition``. The FFI call's
        own return value is ignored; the result is read back from ``self``.

        Args:
            M: Size of the M dimension.
            N: Size of the N dimension.
            block_size: Block size forwarded to the native routine
                (presumably threads per block; confirm against the C++ side).
            target: Compilation target forwarded to the native routine.
            is_wgmma: Flag forwarded unchanged to the native routine.

        Returns:
            The tuple ``(self.m_warp, self.n_warp)`` as read after the FFI call. The
            native routine is expected to update these fields in place.
        """
        _ffi_api.GemmWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
                                                    is_wgmma)
        # Results are written onto this object by the FFI call, not returned by it.
        return self.m_warp, self.n_warp


@tvm_ffi.register_object("tl.GemmSPWarpPolicy")
class GemmSPWarpPolicy(Node, Scriptable):
    """Python binding for the ``tl.GemmSPWarpPolicy`` object type registered through ``tvm_ffi``.

    Attributes:
        policy_type: Integer policy selector stored on the native object.
        m_warp: Warp partition value for the M dimension (presumably the number of
            warps along M); read back by ``compute_warp_partition``.
        n_warp: Warp partition value for the N dimension (presumably the number of
            warps along N); read back by ``compute_warp_partition``.
    """

    policy_type: int
    m_warp: int
    n_warp: int

    def compute_warp_partition(self, M: int, N: int, block_size: int, target: Target,
                               is_wgmma: bool, bits: int):
        """Run the native warp-partition routine and return ``(m_warp, n_warp)``.

        ``M``, ``N``, ``block_size`` and ``bits`` are coerced to Python ``int`` before
        being forwarded to ``_ffi_api.GemmSPWarpPolicyComputeWarpPartition``. The FFI
        call's own return value is ignored; the result is read back from ``self``.

        Args:
            M: Size of the M dimension.
            N: Size of the N dimension.
            block_size: Block size forwarded to the native routine
                (presumably threads per block; confirm against the C++ side).
            target: Compilation target forwarded to the native routine.
            is_wgmma: Flag forwarded unchanged to the native routine.
            bits: Integer bit-width argument forwarded to the native routine
                (presumably the element bit width; confirm).

        Returns:
            The tuple ``(self.m_warp, self.n_warp)`` as read after the FFI call. The
            native routine is expected to update these fields in place.
        """
        # int() on bits matches the coercion already applied to the other integer args.
        _ffi_api.GemmSPWarpPolicyComputeWarpPartition(self, int(M), int(N), int(block_size), target,
                                                      is_wgmma, int(bits))
        # Results are written onto this object by the FFI call, not returned by it.
        return self.m_warp, self.n_warp


@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
    """Python binding for the ``tl.Gemm`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
    """Python binding for the ``tl.GemmSP`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
    """Python binding for the ``tl.FinalizeReducerOp`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
    """Python binding for the ``tl.ParallelOp`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
    """Python binding for the ``tl.ReduceOp`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
    """Python binding for the ``tl.CumSumOp`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
    """Python binding for the ``tl.RegionOp`` object type registered through ``tvm_ffi``."""


@tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
    """Python binding for the ``tl.ReduceType`` object type registered through ``tvm_ffi``."""