# pack_featgraph.py (committed by Zhi Lin)
""" Export featgraph kernels to a shared library. """
import tvm
from sddmm import sddmm_tree_reduction_gpu


def get_sddmm_kernels_gpu(idtypes, dtypes):
    """
    Collect SDDMM tree-reduction GPU kernels for every index/data type pair.

    Parameters
    ----------
    idtypes: List[str]
        Index types to generate kernels for.
    dtypes: List[str]
        Data types to generate kernels for.

    Returns
    -------
    List[IRModule]:
        One entry per (dtype, idtype) combination, produced with ``dtypes``
        as the outer loop and ``idtypes`` as the inner loop.
    """
    # SDDMM Tree Reduction: one kernel per (dtype, idtype) pair, dtype-major.
    return [
        sddmm_tree_reduction_gpu(index_type, data_type)
        for data_type in dtypes
        for index_type in idtypes
    ]


if __name__ == "__main__":
    # Compile every SDDMM GPU kernel variant and export them together into a
    # single shared-library file.
    binary_path = "libfeatgraph_kernels.so"

    # Kernels are generated for every combination of these index/data types.
    idtypes = ["int32", "int64"]
    dtypes = ["float16", "float64", "float32", "int32", "int64"]

    kernels = []
    kernels += get_sddmm_kernels_gpu(idtypes, dtypes)

    # build kernels (CUDA device code, LLVM host code) and export the module
    # to libfeatgraph_kernels.so
    module = tvm.build(kernels, target="cuda", target_host="llvm")
    module.export_library(binary_path)