# sddmm.py
""" The compute function and schedules for SDDMM kernels written in TVM. """
import tvm
from tvm import te


def sddmm_tree_reduction_gpu(idx_type, feat_type):
    """SDDMM kernels on GPU optimized with Tree Reduction.

    The generated kernel takes two index arrays ``row`` and ``col`` of
    length ``nnz`` and two feature tensors ``ufeat`` of shape
    ``(num_rows, num_heads, feat_len)`` and ``vfeat`` of shape
    ``(num_cols, num_heads, feat_len)``. For every position ``e`` and
    head ``h`` it computes::

        out[e, h, 0] = sum_k ufeat[row[e], h, k] * vfeat[col[e], h, k]

    All sizes are symbolic variables, so one lowered module serves any
    input size for the given pair of data types.

    Parameters
    ----------
    idx_type : str
        The data type for indexing tensors.
    feat_type : str
        The data type of feature tensor.

    Returns
    -------
    IRModule
        The result IRModule.
    """
    # Split factors used by the schedule below. The inner part of each
    # split is bound to a thread axis, so these set the thread-block extent.
    reduce_threads = 32  # bound to threadIdx.x
    edges_per_block = 32  # bound to threadIdx.y

    # define vars and placeholders
    nnz = te.var("nnz", idx_type)
    num_rows = te.var("num_rows", idx_type)
    num_cols = te.var("num_cols", idx_type)
    H = te.var("num_heads", idx_type)
    D = te.var("feat_len", idx_type)
    row = te.placeholder((nnz,), idx_type, "row")
    col = te.placeholder((nnz,), idx_type, "col")
    ufeat = te.placeholder((num_rows, H, D), feat_type, "ufeat")
    vfeat = te.placeholder((num_cols, H, D), feat_type, "vfeat")

    # define edge computation function
    # The third index ``i`` ranges over the trailing unit dimension of the
    # output and is intentionally unused. The parameter names are kept
    # as-is because te.compute may use them to name the output axes.
    def edge_func(eid, h, i):
        k = te.reduce_axis((0, D), name="k")
        return te.sum(ufeat[row[eid], h, k] * vfeat[col[eid], h, k], axis=k)

    out = te.compute(
        (nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name="out"
    )

    # define schedules
    sched = te.create_schedule(out.op)
    edge_axis, head_axis, _ = out.op.axis
    reduce_axis = out.op.reduce_axis[0]
    # Only the inner halves of the splits are bound to threads; the outer
    # half of the reduction split is left unbound.
    _, red_inner = sched[out].split(reduce_axis, factor=reduce_threads)
    edge_outer, edge_inner = sched[out].split(
        edge_axis, factor=edges_per_block
    )
    # Binding the reduction axis to threadIdx.x spreads the feature-length
    # reduction across threads; edges and heads are spread over the rest
    # of the launch grid.
    sched[out].bind(red_inner, te.thread_axis("threadIdx.x"))
    sched[out].bind(edge_inner, te.thread_axis("threadIdx.y"))
    sched[out].bind(edge_outer, te.thread_axis("blockIdx.x"))
    sched[out].bind(head_axis, te.thread_axis("blockIdx.y"))
    return tvm.lower(
        sched,
        [row, col, ufeat, vfeat, out],
        name="SDDMMTreeReduction_{}_{}".format(idx_type, feat_type),
    )


# Script entry point.
if __name__ == "__main__":
    # Lower the int32-index / float32-feature variant and print the
    # resulting module so the generated IR can be inspected by hand.
    kernel0 = sddmm_tree_reduction_gpu("int32", "float32")
    print(kernel0)