test_tilelang_carver_generate_hints.py 2.94 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import tilelang.testing
from tilelang import carver
from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge
from tilelang.carver.arch import auto_infer_current_arch
from tvm import te


def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
    """Emit tensor-core tiling hints for one fp16 GEMM through two policy entry points.

    Builds ``C[i, j] = sum_k A[i, k] * B[j, k]`` in TE (``B`` is laid out as
    ``(N, K)``), passes it through
    ``carver.utils.get_tensorized_func_and_tags`` for the architecture returned
    by ``auto_infer_current_arch()``, and then requests configs from
    ``carver.TensorCorePolicy`` twice:

    1. via ``from_prim_func`` on the tensorized PrimFunc, and
    2. via ``from_output_nodes`` on a single-node graph wrapping that PrimFunc.

    Both paths must return at least one hint.

    Args:
        M, N, K: GEMM problem sizes.
        topk: Maximum number of hints requested from *each* policy.
    """
    arch = auto_infer_current_arch()

    def gemm(M, N, K):
        A = te.placeholder((M, K), name='A', dtype='float16')
        B = te.placeholder((N, K), name='B', dtype='float16')

        # Describe the matrix multiplication in TE; B is indexed [j, k], i.e. A @ B^T.
        k = te.reduce_axis((0, K), name='k')

        C = te.compute(
            (M, N),
            lambda i, j: te.sum(A[i, k].astype('float16') * B[j, k].astype('float16'), axis=[k]),
            name='C')

        return A, B, C

    func = te.create_prim_func(gemm(M, N, K))

    tensorized_func, tags = carver.utils.get_tensorized_func_and_tags(func, arch.target)
    print(tags)

    # Path 1: policy constructed directly from the tensorized PrimFunc.
    policy = carver.TensorCorePolicy.from_prim_func(
        func=tensorized_func, arch=arch, tags=tags, name="matmul_0")

    hints = policy.emit_config(topk=topk)
    for hint in hints:
        print(hint)
    assert len(hints) > 0, "Hints length is zero"

    # Path 2: the same PrimFunc wrapped as a one-node graph. This path uses the
    # caller's ``topk`` too, so both paths honour the same limit.
    prim_func_node = PrimFuncNode(tensorized_func, name="matmul_1")
    output_nodes = [OutputNode(prim_func_node)]
    policy = carver.TensorCorePolicy.from_output_nodes(output_nodes, arch=arch, tags=tags)

    hints = policy.emit_config(topk=topk)
    for hint in hints:
        print(hint)
    assert len(hints) > 0, "Hints length is zero"


def test_general_matmul_emit_configs():
    """Smoke-test hint emission for a square 128x128x128 fp16 GEMM."""
    edge = 128
    run_general_matmul_emit_configs(edge, edge, edge)


def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
    """Emit tensor-core hints for a hand-wired chain of two identical GEMM nodes.

    One fp16 GEMM (``C[i, j] = sum_k A[i, k] * B[j, k]``) is tensorized, then
    used as the body of two ``PrimFuncNode`` instances that are linked with an
    ``Edge``. ``carver.TensorCorePolicy.from_output_nodes`` is queried for up
    to ``topk`` configs on that graph, and at least one must be returned.
    """
    target_arch = auto_infer_current_arch()

    def build_gemm(m, n, k_extent):
        lhs = te.placeholder((m, k_extent), name='A', dtype='float16')
        rhs = te.placeholder((n, k_extent), name='B', dtype='float16')
        # Reduction axis over K; rhs is indexed [j, k], i.e. A @ B^T.
        red = te.reduce_axis((0, k_extent), name='k')
        # NOTE: keep the lambda argument names `i, j` -- TE names the output axes after them.
        out = te.compute(
            (m, n),
            lambda i, j: te.sum(
                lhs[i, red].astype('float16') * rhs[j, red].astype('float16'),
                axis=[red]),
            name='C')
        return lhs, rhs, out

    prim_func = te.create_prim_func(build_gemm(M, N, K))
    tensorized, tags = carver.utils.get_tensorized_func_and_tags(prim_func, target_arch.target)
    print(tags)

    producer = PrimFuncNode(tensorized, name="matmul_0")
    consumer = PrimFuncNode(tensorized, name="matmul_1")

    # Register the link on both endpoints. The (0, 0) arguments are presumably
    # producer output index 0 -> consumer input index 0 -- confirm against Edge.
    # NOTE(review): appends to the private `_out_edges` list directly.
    link = Edge(producer, consumer, 0, 0)
    producer._out_edges.append(link)
    consumer.set_inputs(0, link)

    policy = carver.TensorCorePolicy.from_output_nodes(
        [OutputNode(consumer)], arch=target_arch, tags=tags)
    configs = policy.emit_config(topk=topk)

    for cfg in configs:
        print(cfg)

    assert len(configs) > 0, "Hints length is zero"


def test_general_matmul_matmul_emit_configs():
    """Smoke-test hint emission for a chain of two 128x128x128 fp16 GEMMs."""
    edge = 128
    run_general_matmul_matmul_emit_configs(edge, edge, edge)


if __name__ == "__main__":
    # Allow executing this module directly via tilelang's test entry point.
    _run_tests = tilelang.testing.main
    _run_tests()