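# test_tilelang_language_vectorize.py
# Tests for TileLang's T.vectorized loops: numerical results and the vector
# width (float2/float4) that appears in the generated kernel source.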
import torch
import tilelang.testing
import tilelang.language as T


@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test(N, M, stride_A, stride_B):
    """Build a kernel that copies ``A`` into ``B`` element by element.

    Both tensors are declared as ``(N, M)`` float32 strided tensors with
    strides ``(1, stride_A)`` and ``(1, stride_B)`` respectively. Each thread
    owns one column and walks all ``N`` rows of it inside a ``T.vectorized``
    loop.

    Args:
        N: Number of rows; extent of the vectorized loop.
        M: Number of columns; the launch uses ``M // 128`` blocks of 128 threads.
        stride_A: Stride of ``A`` along its second dimension.
        stride_B: Stride of ``B`` along its second dimension.

    Returns:
        The ``main`` prim_func defined below (this builder itself is wrapped
        by the ``tilelang.jit`` decorator).
    """

    @T.prim_func
    def main(
        A: T.StridedTensor[(N, M), (1, stride_A), T.float32],  # noqa: F821
        B: T.StridedTensor[(N, M), (1, stride_B), T.float32],  # noqa: F821
    ):
        with T.Kernel(M // 128, threads=128) as (bx):
            tx = T.get_thread_binding(0)
            # One column per (block, thread) pair.
            col = bx * 128 + tx

            for row in T.vectorized(N):
                B[row, col] = A[row, col]

    return main


def run_vectorize(N, M, stride_A, stride_B):
    """Run the strided copy kernel and check its output and generated source.

    Builds ``(N, M)`` views with strides ``(1, stride_A)`` / ``(1, stride_B)``
    over CUDA buffers, runs the kernel produced by ``vectorize_test``, checks
    that the destination matches the source, and then inspects the kernel
    source for the vector type implied by the strides.

    Args:
        N: Number of rows; must be a multiple of 128.
        M: Number of columns; must be a multiple of 128.
        stride_A: Second-dimension stride of the source view; must be >= N.
        stride_B: Second-dimension stride of the destination view; must be >= N.

    Raises:
        AssertionError: If a precondition fails, the copy result differs, or
            the expected ``float2``/``float4`` token is missing from the source.
    """
    assert N % 128 == 0 and M % 128 == 0
    assert stride_A >= N and stride_B >= N

    jit_kernel = vectorize_test(N, M, stride_A, stride_B)

    # Each backing buffer holds stride * M elements, which covers the largest
    # offset, (N - 1) + (M - 1) * stride, reached by the strided views below.
    base_a = torch.randn(stride_A, M, device="cuda", dtype=torch.float32)
    base_b = torch.zeros(stride_B, M, device="cuda", dtype=torch.float32)
    a = torch.as_strided(base_a, size=(N, M), stride=(1, stride_A))
    b = torch.as_strided(base_b, size=(N, M), stride=(1, stride_B))

    jit_kernel(a, b)

    # ``b`` is the kernel output (actual); ``a`` is the reference (expected).
    torch.testing.assert_close(b, a, atol=1e-8, rtol=1e-8)

    code = jit_kernel.get_kernel_source()

    # Largest power of two, capped at 4, that divides both strides.
    vectorize_size = 1
    while vectorize_size <= 2 and stride_A % (vectorize_size * 2) == 0 and stride_B % (vectorize_size * 2) == 0:
        vectorize_size *= 2

    # A width of 1 carries no expectation about the generated source.
    if vectorize_size == 4:
        assert "float4" in code
    elif vectorize_size == 2:
        assert "float2" in code


def test_vectorize():
    """Exercise the strided copy kernel with several stride paddings."""
    rows, cols = 512, 256

    # (extra stride for A, extra stride for B) added on top of the row count.
    stride_paddings = ((0, 0), (2, 4), (4, 8), (8, 16))
    for pad_a, pad_b in stride_paddings:
        run_vectorize(rows, cols, rows + pad_a, rows + pad_b)


@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_VECTORIZE_256: True})
def vectorize_test_invariant_index(N, M, K):
    """Build a kernel computing ``B[row, col] = A[row, col] * C[row, col // K]``.

    ``A`` and ``B`` are ``(N, M)`` float32 tensors and ``C`` is
    ``(N, M // K)``, so every run of ``K`` consecutive columns reads the same
    element of ``C``. Each thread owns one row and walks all ``M`` columns of
    it inside a ``T.vectorized`` loop.

    Args:
        N: Number of rows; the launch uses ``N // 128`` blocks of 128 threads.
        M: Number of columns; extent of the vectorized loop.
        K: Number of consecutive columns that share one element of ``C``.

    Returns:
        The ``main`` prim_func defined below (this builder itself is wrapped
        by the ``tilelang.jit`` decorator).
    """

    @T.prim_func
    def main(
        A: T.Tensor[(N, M), T.float32],  # noqa: F821
        B: T.Tensor[(N, M), T.float32],  # noqa: F821
        C: T.Tensor[(N, M // K), T.float32],  # noqa: F821
    ):
        with T.Kernel(N // 128, threads=128) as (bx):
            tx = T.get_thread_binding(0)
            # One row per (block, thread) pair.
            row = bx * 128 + tx

            for col in T.vectorized(M):
                B[row, col] = A[row, col] * C[row, col // K]

    return main


def run_vectorize_invariant_index(N, M, K):
    """Run the invariant-index kernel and check its output and generated source.

    Runs the kernel produced by ``vectorize_test_invariant_index`` on random
    CUDA inputs, compares the result with a PyTorch reference, and then
    inspects the kernel source for the vector type implied by ``K``.

    Args:
        N: Number of rows; must be a multiple of 128.
        M: Number of columns; must be a multiple of ``K``.
        K: Number of consecutive columns that share one element of ``c``.

    Raises:
        AssertionError: If a precondition fails, the result differs from the
            reference, or the expected ``float2``/``float4`` token is missing
            from the source.
    """
    assert N % 128 == 0 and M % K == 0

    jit_kernel = vectorize_test_invariant_index(N, M, K)

    a = torch.randn(N, M, device="cuda", dtype=torch.float32)
    b = torch.zeros(N, M, device="cuda", dtype=torch.float32)
    c = torch.randn(N, M // K, device="cuda", dtype=torch.float32)

    jit_kernel(a, b, c)

    # Column j of the reference reads column j // K of ``c``. The index tensor
    # is created on the same device as the data so the gather does not rely on
    # indexing a CUDA tensor with a CPU tensor.
    indices = torch.arange(a.size(1), device=a.device) // K
    ret = a * c[:, indices]
    torch.testing.assert_close(b, ret, atol=1e-8, rtol=1e-8)

    code = jit_kernel.get_kernel_source()

    # Largest power of two, capped at 4, that divides K.
    vectorize_size = 1
    while vectorize_size <= 2 and K % (vectorize_size * 2) == 0:
        vectorize_size *= 2

    # A width of 1 carries no expectation about the generated source.
    if vectorize_size == 4:
        assert "float4" in code
    elif vectorize_size == 2:
        assert "float2" in code


def test_vectorize_invariant_index():
    """Exercise the invariant-index kernel over several group sizes ``K``."""
    rows, base_cols = 512, 256

    # (column multiplier, K): the multiplier keeps the column count divisible by K.
    cases = ((1, 2), (1, 4), (3, 6), (1, 8), (3, 12), (7, 14))
    for col_factor, group in cases:
        run_vectorize_invariant_index(rows, base_cols * col_factor, group)


# Hand control to tilelang.testing.main() when this file is executed directly.
if __name__ == "__main__":
    tilelang.testing.main()