test_tilelang_language_copy.py 6.7 KB
Newer Older
1
2
3
import tilelang
import tilelang.language as T
import torch
4
import tilelang.testing
5

6
7
# Print the installed PyTorch version when this module is imported.
print(torch.__version__)

8

9
10
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
11
def tilelang_copy(M, N, block_M, block_N, src_dtype=T.float16, dst_dtype=T.float16):
    """Build a prim_func that copies an (M, N) tensor tile by tile with ``T.copy``.

    Args:
        M: Number of rows of the source and destination tensors.
        N: Number of columns of the source and destination tensors.
        block_M: Number of rows handled by one kernel block.
        block_N: Number of columns handled by one kernel block.
        src_dtype: Element type of the source tensor ``A``.
        dst_dtype: Element type of the destination tensor ``B``.

    Returns:
        The ``T.prim_func`` named ``main``; compilation is left to the caller.
    """

    @T.prim_func
    def main(
        A: T.Tensor((M, N), src_dtype),
        B: T.Tensor((M, N), dst_dtype),
    ):
        # Launch one block per tile: ``bx`` walks tiles along N, ``by`` along M.
        with T.Kernel(
            T.ceildiv(N, block_N),
            T.ceildiv(M, block_M),
            threads=128,
        ) as (bx, by):
            # Copy the (block_M, block_N) region owned by this block from A to B.
            T.copy(
                A[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N],
                B[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N],
            )

    return main


27
def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
    """Compile the tiled ``T.copy`` kernel for CUDA and check the output equals the input.

    Args:
        M: Number of rows of the test tensor.
        N: Number of columns of the test tensor.
        block_M: Tile height passed to ``tilelang_copy``.
        block_N: Tile width passed to ``tilelang_copy``.
        dtype: Element type used for both the source and destination tensors.

    Raises:
        AssertionError: If the kernel output is not close to the input.
    """
    pass_configs = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    kernel = tilelang.compile(
        tilelang_copy(M, N, block_M, block_N, src_dtype=dtype, dst_dtype=dtype),
        out_idx=[1],
        target="cuda",
        pass_configs=pass_configs,
    )
    # Dump the generated kernel source to stdout for inspection.
    print(kernel.get_kernel_source())

    # ``dtype`` is used as an attribute name on ``torch`` to get the matching torch dtype.
    torch_dtype = getattr(torch, dtype)
    expected = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    actual = kernel(expected)
    torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
def test_tilelang_copy():
    """Run the tiled copy check for square and non-square shapes and two dtypes.

    The helper compiles for ``target="cuda"`` and allocates CUDA tensors, so the
    test carries the same ``requires_cuda`` guard as the other CUDA tests in this
    file.
    """
    run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128)
    # A single tile spans the whole (non power-of-two) row width here.
    run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576)
    run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32)
46
47


48
def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype=T.float16):
    """Build a prim_func that copies a strided (M, N) source into a plain (M, N) tensor.

    Args:
        M: Number of rows of both tensors.
        N: Number of columns of both tensors.
        NN: Row stride of the source tensor ``A`` (its strides are ``(NN, 1)``).
        block_M: Number of rows handled by one kernel block.
        block_N: Number of columns handled by one kernel block.
        dtype: Element type of both tensors.

    Returns:
        The ``T.prim_func`` named ``main``; compilation is left to the caller.
    """

    @T.prim_func
    def main(
        A: T.StridedTensor((M, N), (NN, 1), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Launch one block per tile: ``bx`` walks tiles along N, ``by`` along M.
        with T.Kernel(
            T.ceildiv(N, block_N),
            T.ceildiv(M, block_M),
            threads=128,
        ) as (bx, by):
            # Element-wise copy of this block's tile.
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j]

    return main


62
def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype=T.float16):
    """Compile the strided copy kernel for CUDA and check it on a column-sliced tensor.

    Args:
        M: Number of rows of the copied region.
        N: Number of columns of the copied region.
        NN: Row stride of the source; either an ``int`` or a ``T.Var``.
        block_M: Tile height passed to ``tilelang_copy_with_stride``.
        block_N: Tile width passed to ``tilelang_copy_with_stride``.
        dtype: Element type of both tensors.

    Raises:
        AssertionError: If an integer ``NN`` is not greater than ``N``, or if the
            kernel output is not close to the sliced input.
    """
    if isinstance(NN, int):
        assert NN > N, "NN must be greater than N"

    kernel = tilelang.compile(
        tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype),
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        },
    )

    # A ``T.Var`` stride has no concrete value, so pick one for the backing tensor.
    if isinstance(NN, T.Var):
        NN = N * 2
    backing = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype))
    # Slicing the first N columns of the (M, NN) tensor gives the strided input.
    result = kernel(backing[:, :N])
    torch.testing.assert_close(result, backing[:, :N], rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
def test_tilelang_copy_with_stride():
    """Run the strided copy check with an integer stride and a dynamic stride.

    The helper compiles for ``target="cuda"`` and allocates CUDA tensors, so the
    test carries the same ``requires_cuda`` guard as the other CUDA tests in this
    file.
    """
    run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128)
    # Same shapes, but the row stride is the symbolic value ``T.dynamic("NN")``.
    run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128)
85
86


87
def tilelang_copy_bufferload(num_tokens, dtype=T.float16):
    """Build a prim_func that uses ``T.copy`` between two single buffer elements.

    Args:
        num_tokens: Length of both 1-D tensors and the number of kernel blocks.
        dtype: Element type of ``x``.

    Returns:
        The ``T.prim_func`` named ``main``; compilation is left to the caller.
    """

    @T.prim_func
    def main(
        indices: T.Tensor((num_tokens,), T.int32),
        x: T.Tensor((num_tokens,), dtype),
    ):
        with T.Kernel(num_tokens, threads=32) as pid:
            # Single-element int32 buffer that receives this block's index.
            idx = T.alloc_local([1], T.int32)
            # Source and destination of T.copy are scalar element accesses here.
            T.copy(indices[pid], idx[0])
            x[idx[0]] = x[idx[0]] + 1

    return main


101
def run_tilelang_copy_bufferload(num_tokens=128, dtype=T.float16):
    """Compile the element-wise ``T.copy`` kernel; the compiled kernel is not executed.

    Args:
        num_tokens: Length passed to ``tilelang_copy_bufferload``.
        dtype: Element type passed to ``tilelang_copy_bufferload``.
    """
    pass_configs = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    # test compilation only
    tilelang.compile(
        tilelang_copy_bufferload(num_tokens, dtype),
        out_idx=[1],
        pass_configs=pass_configs,
    )
109
110
111
112
113
114


def test_tilelang_copy_bufferload():
    """Check that the element-wise ``T.copy`` kernel compiles for 128 tokens."""
    run_tilelang_copy_bufferload(num_tokens=128)


115
def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype=T.float16):
    """Build a prim_func that calls ``T.copy`` per element inside a ``T.Parallel`` loop.

    Args:
        M: Number of rows of both tensors.
        N: Number of columns of both tensors.
        block_M: Number of rows handled by one kernel block.
        block_N: Number of columns handled by one kernel block.
        dtype: Element type of both tensors.

    Returns:
        The ``T.prim_func`` named ``main``; compilation is left to the caller.
    """

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
    ):
        # Launch one block per tile: ``bx`` walks tiles along N, ``by`` along M.
        with T.Kernel(
            T.ceildiv(N, block_N),
            T.ceildiv(M, block_M),
            threads=128,
        ) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                # Scalar-to-scalar T.copy for each element of this block's tile.
                T.copy(A[by * block_M + i, bx * block_N + j], B[by * block_M + i, bx * block_N + j])

    return main


129
def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
    """Compile the per-element ``T.copy`` kernel for CUDA and check output equals input.

    Args:
        M: Number of rows of the test tensor.
        N: Number of columns of the test tensor.
        block_M: Tile height passed to the kernel builder.
        block_N: Tile width passed to the kernel builder.
        dtype: Element type of both tensors.

    Raises:
        AssertionError: If the kernel output is not close to the input.
    """
    pass_configs = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    kernel = tilelang.compile(
        tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype),
        out_idx=[1],
        target="cuda",
        pass_configs=pass_configs,
    )

    torch_dtype = getattr(torch, dtype)
    expected = torch.randn(M, N, device="cuda", dtype=torch_dtype)
    actual = kernel(expected)
    torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
def test_tilelang_copy_buffer_load_with_parallel():
    """Run the per-element ``T.copy`` check on a 1024x1024 tensor with 128x128 tiles.

    The helper compiles for ``target="cuda"`` and allocates CUDA tensors, so the
    test carries the same ``requires_cuda`` guard as the other CUDA tests in this
    file.
    """
    run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128)


146
def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu):
    """Compile and run the tiled copy kernel on ``float8_e8m0fnu`` data.

    Checks that the generated source mentions ``fp8_e8_t`` and that running the
    kernel returns a value; the output contents are not compared.

    Args:
        M: Number of rows of the test tensor.
        N: Number of columns of the test tensor.
        block_M: Tile height passed to ``tilelang_copy``.
        block_N: Tile width passed to ``tilelang_copy``.
        src_dtype: Element type of the kernel's source tensor.
        dst_dtype: Element type of the kernel's destination tensor.

    Raises:
        AssertionError: If the type name is missing from the source or the
            kernel returns ``None``.
    """
    kernel = tilelang.compile(
        tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype),
        out_idx=[1],
    )
    assert "fp8_e8_t" in kernel.get_kernel_source()

    # Random int8 values reinterpreted as float8_e8m0fnu through ``view``.
    raw_values = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8)
    result = kernel(raw_values.view(torch.float8_e8m0fnu))
    assert result is not None


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp8_e8m0():
    """Run the ``float8_e8m0fnu`` to ``float8_e8m0fnu`` copy check."""
    run_tilelang_copy_fp8_e8m0(
        src_dtype=T.float8_e8m0fnu,
        dst_dtype=T.float8_e8m0fnu,
    )
163
164


165
def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn):
    """Compile and run the tiled copy kernel with a ``float4_e2m1fn`` source by default.

    Checks that the generated source mentions ``fp4_e2_t`` and that running the
    kernel returns a value; the output contents are not compared.

    Args:
        M: Number of rows of the test tensor.
        N: Number of columns of the test tensor.
        block_M: Tile height passed to ``tilelang_copy``.
        block_N: Tile width passed to ``tilelang_copy``.
        src_dtype: Element type of the kernel's source tensor.
        dst_dtype: Element type of the kernel's destination tensor.

    Raises:
        AssertionError: If the type name is missing from the source or the
            kernel returns ``None``.
    """
    kernel = tilelang.compile(
        tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype),
        out_idx=[1],
    )
    assert "fp4_e2_t" in kernel.get_kernel_source()

    # For FP4, use same shape as kernel expects, since int8 is used as storage type
    packed_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8)
    result = kernel(packed_input)
    assert result is not None


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp4():
    """Run the FP4 copy check from ``float4_e2m1fn`` into three destination dtypes."""
    # Same order as before: FP4 -> FP4, then FP4 -> float16, then FP4 -> bfloat16.
    for dst_dtype in (T.float4_e2m1fn, T.float16, T.bfloat16):
        run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=dst_dtype)
185
186


187
188
# When executed as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()