""" The compute function and schedules for SDDMM kernels written in TVM. """ import tvm from tvm import te def sddmm_tree_reduction_gpu(idx_type, feat_type): """SDDMM kernels on GPU optimized with Tree Reduction. Parameters ---------- idx_type : str The data type for indexing tensors. feat_type : str The data type of feature tensor. Returns ------- IRModule The result IRModule. """ # define vars and placeholders nnz = te.var("nnz", idx_type) num_rows = te.var("num_rows", idx_type) num_cols = te.var("num_cols", idx_type) H = te.var("num_heads", idx_type) D = te.var("feat_len", idx_type) row = te.placeholder((nnz,), idx_type, "row") col = te.placeholder((nnz,), idx_type, "col") ufeat = te.placeholder((num_rows, H, D), feat_type, "ufeat") vfeat = te.placeholder((num_cols, H, D), feat_type, "vfeat") # define edge computation function def edge_func(eid, h, i): k = te.reduce_axis((0, D), name="k") return te.sum(ufeat[row[eid], h, k] * vfeat[col[eid], h, k], axis=k) out = te.compute( (nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name="out" ) # define schedules sched = te.create_schedule(out.op) edge_axis, head_axis, _ = out.op.axis reduce_axis = out.op.reduce_axis[0] _, red_inner = sched[out].split(reduce_axis, factor=32) edge_outer, edge_inner = sched[out].split(edge_axis, factor=32) sched[out].bind(red_inner, te.thread_axis("threadIdx.x")) sched[out].bind(edge_inner, te.thread_axis("threadIdx.y")) sched[out].bind(edge_outer, te.thread_axis("blockIdx.x")) sched[out].bind(head_axis, te.thread_axis("blockIdx.y")) return tvm.lower( sched, [row, col, ufeat, vfeat, out], name="SDDMMTreeReduction_{}_{}".format(idx_type, feat_type), ) if __name__ == "__main__": kernel0 = sddmm_tree_reduction_gpu("int32", "float32") print(kernel0)