# test_tilelang_carver_recommend_hints.py: checks that tilelang carver templates produce recommended hints.
from typing import List, Optional

import tilelang.testing
from tilelang import carver
from tilelang.carver.arch import auto_infer_current_arch
from tilelang.language import dtypes as T


def run_general_reduction_recommend_hints(
    structure: str = "SSR",
    shape: Optional[List[int]] = None,
    dtype: T.dtype = T.float16,
    topk: int = 20,
):
    """Build a ``GeneralReductionTemplate`` for the current arch and check it yields hints.

    Args:
        structure: Axis-structure string forwarded unchanged to
            ``carver.GeneralReductionTemplate`` (the tests use "SSR", "SS", "SRS").
        shape: Problem shape forwarded to the template. Defaults to ``None``;
            whether the template accepts ``None`` is not visible here.
        dtype: Element dtype forwarded to the template.
        topk: Number of hints requested from ``recommend_hints``.

    Raises:
        AssertionError: if the template has no equivalent function or
            ``recommend_hints`` returns an empty result.
    """
    arch = auto_infer_current_arch()
    carve_template = carver.GeneralReductionTemplate(
        structure=structure,
        shape=shape,
        dtype=dtype,
    ).with_arch(arch)

    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    hints = carve_template.recommend_hints(topk=topk)
    # Only non-emptiness is checked, not that exactly ``topk`` hints come back.
    assert len(hints) > 0, "Hints length is zero"


def test_general_reduction_recommend_hints():
    """Run the general-reduction hint check over 3-D and 2-D float16 problems."""
    for structure, shape in (
        ("SSR", [1024, 1024, 1024]),
        ("SS", [1024, 1024]),
        ("SRS", [1024, 1024, 1024]),
    ):
        run_general_reduction_recommend_hints(structure, shape, T.float16)


def run_elementwise_recommend_hints(
    shape: Optional[List[int]] = None,
    dtype: T.dtype = T.float16,
    topk: int = 20,
):
    """Build an ``ElementwiseTemplate`` for the current arch and check it yields hints.

    Args:
        shape: Problem shape forwarded to ``carver.ElementwiseTemplate``.
            Defaults to ``None``; whether the template accepts ``None`` is not
            visible here.
        dtype: Element dtype forwarded to the template.
        topk: Number of hints requested from ``recommend_hints``.

    Raises:
        AssertionError: if the template has no equivalent function or
            ``recommend_hints`` returns an empty result.
    """
    arch = auto_infer_current_arch()
    carve_template = carver.ElementwiseTemplate(
        shape=shape,
        dtype=dtype,
    ).with_arch(arch)

    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    hints = carve_template.recommend_hints(topk=topk)
    # Only non-emptiness is checked, so the message must not claim a topk mismatch.
    assert len(hints) > 0, "Hints length is zero"


def test_elementwise_recommend_hints():
    """Run the elementwise hint check on 2-D, 1-D and 3-D float16 shapes."""
    shapes = ([1024, 1024], [1024], [1024, 1024, 1024])
    for shape in shapes:
        run_elementwise_recommend_hints(shape, T.float16)


def run_matmul_recommend_hints(
    M: int = 1024,
    N: int = 1024,
    K: int = 1024,
    in_dtype: T.dtype = T.float16,
    out_dtype: T.dtype = T.float16,
    accum_dtype: T.dtype = T.float16,
    topk: int = 20,
):
    """Build a ``MatmulTemplate`` for the current arch and check it yields hints.

    Args:
        M, N, K: Problem sizes forwarded to ``carver.MatmulTemplate``.
        in_dtype: Input dtype forwarded to the template.
        out_dtype: Output dtype forwarded to the template.
        accum_dtype: Accumulator dtype forwarded to the template.
        topk: Number of hints requested from ``recommend_hints``. It was
            previously hard-coded to 20, which remains the default.

    Raises:
        AssertionError: if the template has no equivalent function or
            ``recommend_hints`` returns an empty result.
    """
    arch = auto_infer_current_arch()
    carve_template = carver.MatmulTemplate(
        M=M,
        N=N,
        K=K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
    ).with_arch(arch)

    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    hints = carve_template.recommend_hints(topk=topk)
    # Only non-emptiness is checked, so the message must not claim a count mismatch.
    assert len(hints) > 0, "Hints length is zero"


def test_matmul_recommend_hints():
    """Run the matmul hint check on a 1024x1024x1024 problem with several dtype combos."""
    # Each tuple is (in_dtype, out_dtype, accum_dtype), matching the positional
    # order of run_matmul_recommend_hints.
    # NOTE(review): the last combo pairs out_dtype=float32 with accum_dtype=float16.
    # Confirm this is intended and not a swapped (float16 out, float32 accum).
    dtype_combos = [
        (T.float16, T.float16, T.float16),
        (T.int8, T.int32, T.int32),
        (T.float16, T.float32, T.float16),
    ]
    for in_dtype, out_dtype, accum_dtype in dtype_combos:
        run_matmul_recommend_hints(1024, 1024, 1024, in_dtype, out_dtype, accum_dtype)


def run_gemv_recommend_hints(
    N: int = 1024,
    K: int = 1024,
    in_dtype: T.dtype = T.float16,
    out_dtype: T.dtype = T.float16,
    accum_dtype: T.dtype = T.float16,
    topk: int = 20,
):
    """Build a ``GEMVTemplate`` for the current arch and check it yields hints.

    Args:
        N, K: Problem sizes forwarded to ``carver.GEMVTemplate``.
        in_dtype: Input dtype forwarded to the template.
        out_dtype: Output dtype forwarded to the template.
        accum_dtype: Accumulator dtype forwarded to the template.
        topk: Number of hints requested from ``recommend_hints``. It was
            previously hard-coded to 20, which remains the default.

    Raises:
        AssertionError: if the template has no equivalent function or
            ``recommend_hints`` returns an empty result.
    """
    arch = auto_infer_current_arch()
    carve_template = carver.GEMVTemplate(
        N=N,
        K=K,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=accum_dtype,
    ).with_arch(arch)

    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    hints = carve_template.recommend_hints(topk=topk)
    # Only non-emptiness is checked, so the message must not claim a count mismatch.
    assert len(hints) > 0, "Hints length is zero"


def test_gemv_recommend_hints():
    """Run the GEMV hint check with N=K=1024 across several dtype combos."""
    # Each tuple is (in_dtype, out_dtype, accum_dtype), matching the positional
    # order of run_gemv_recommend_hints.
    # NOTE(review): the last combo pairs out_dtype=float32 with accum_dtype=float16.
    # Confirm this is intended.
    dtype_combos = [
        (T.float16, T.float16, T.float16),
        (T.int8, T.int32, T.int32),
        (T.float16, T.float32, T.float16),
    ]
    for in_dtype, out_dtype, accum_dtype in dtype_combos:
        run_gemv_recommend_hints(1024, 1024, in_dtype, out_dtype, accum_dtype)


def run_fmha_recommend_hints(
    batch_size: int = 4,
    num_heads: int = 32,
    seq_length: int = 512,
    seq_kv_length: int = 512,
    head_dim: int = 128,
    in_dtype: T.dtype = T.float16,
    accum_dtype: T.dtype = T.float16,
    out_dtype: T.dtype = T.float16,
    topk: int = 20,
):
    """Build a ``FlashAttentionTemplate`` for the current arch and check it yields hints.

    Each recommended hint is printed to aid debugging.

    Args:
        batch_size, num_heads, seq_length, seq_kv_length, head_dim: Problem
            sizes forwarded to ``carver.FlashAttentionTemplate``.
        in_dtype: Input dtype forwarded to the template.
        accum_dtype: Accumulator dtype forwarded to the template. The
            positional order here is (in, accum, out), unlike the matmul/gemv
            helpers.
        out_dtype: Output dtype forwarded to the template.
        topk: Number of hints requested from ``recommend_hints``. It was
            previously hard-coded to 20, which remains the default.

    Raises:
        AssertionError: if the template has no equivalent function or
            ``recommend_hints`` returns an empty result.
    """
    arch = auto_infer_current_arch()
    carve_template = carver.FlashAttentionTemplate(
        batch_size=batch_size,
        num_heads=num_heads,
        seq_length=seq_length,
        seq_kv_length=seq_kv_length,
        head_dim=head_dim,
        in_dtype=in_dtype,
        accum_dtype=accum_dtype,
        out_dtype=out_dtype,
    ).with_arch(arch)

    func = carve_template.equivalent_function()
    assert func is not None, "Function is None"

    hints = carve_template.recommend_hints(topk=topk)
    for hint in hints:
        print(hint)
    assert len(hints) > 0, "Hints length should be greater than 0"


def test_fmha_recommend_hints():
    """Run the flash-attention hint check with float16 and int8 inputs."""
    # (batch_size, num_heads, seq_length, seq_kv_length, head_dim), shared by both combos.
    problem = (4, 32, 512, 512, 128)
    # Each tuple is (in_dtype, accum_dtype, out_dtype). This order differs from
    # the matmul/gemv tests.
    for in_dtype, accum_dtype, out_dtype in (
        (T.float16, T.float16, T.float16),
        (T.int8, T.int32, T.int32),
    ):
        run_fmha_recommend_hints(*problem, in_dtype, accum_dtype, out_dtype)


if __name__ == "__main__":
    # When run as a script, hand control to tilelang's test runner.
    tilelang.testing.main()