"""
Test layout inference for LetStmt expressions.

This test validates that TileLang correctly handles layout inference when
fragment buffer accesses occur through let bindings. For example:

    block_mask_f = T.alloc_fragment((N_S,), T.int32)
    T.copy(BlockMask[by, :], block_mask_f)
    for i in T.Pipelined(N_S):
        a = block_mask_f[i]  # LetStmt: a is bound to fragment buffer load
        T.copy(A[a, 0], A_shared)  # a is used as index in TMA copy

Key scenarios tested:
1. Fragment buffer layout inference through let bindings
2. TMA (Tensor Memory Accelerator) copy with let-bound indices
3. CP.ASYNC copy with let-bound indices
4. Warp specialization with let-bound fragment accesses
"""

import tilelang
import tilelang.language as T
import tilelang.testing
import torch


def blocksparse_copy_kernel(M: int, N: int, N_S: int, block_M: int, block_N: int, dtype=T.float16):
    """BlockSparse copy kernel using fragment for block mask indices.

    Builds a TileLang prim_func that, for each output tile (block-row ``by``,
    block-column ``i``), reads a source row offset out of ``BlockMask`` via a
    fragment buffer and copies the corresponding ``block_M x block_N`` tile of
    ``A`` into ``B``. The fragment read ``block_mask_f[i]`` inside the
    pipelined loop is bound through a LetStmt — exactly the layout-inference
    path this test exercises (see module docstring).

    Args:
        M, N: Logical dimensions of ``A`` and ``B``.
        N_S: Number of column blocks (callers pass ``N // block_N``).
        block_M, block_N: Tile dimensions.
        dtype: Element type of ``A``/``B`` (default ``T.float16``).

    Returns:
        The TileLang prim_func ``main(A, B, BlockMask)``.
    """
    # One mask entry per (row block, column block) pair.
    block_mask_shape = (M // block_M, N_S)

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
        BlockMask: T.Tensor(block_mask_shape, T.int32),
    ):
        # Grid: bx over column blocks, by over row blocks.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            B_shared = T.alloc_shared((block_M, block_N), dtype)
            # Fragment holding this row block's mask entries; reads from it
            # below produce the let-bound fragment accesses under test.
            block_mask_f = T.alloc_fragment((N_S,), T.int32)

            T.clear(B_shared)
            T.copy(BlockMask[by, :], block_mask_f)
            for i in T.Pipelined(N_S):
                a = block_mask_f[i]  # LetStmt: fragment buffer access
                # NOTE(review): the reference impl copies unconditionally;
                # this guard only matters if a mask entry were negative —
                # run_blocksparse_copy always generates entries >= 0.
                if a >= 0:
                    # 'a' (a let-bound fragment value) is used as the source
                    # row index of the copy — the TMA/CP.ASYNC lowering path
                    # must infer its layout correctly.
                    T.copy(A[a, 0], A_shared)
                    T.copy(A_shared, B[by * block_M : (by + 1) * block_M, i * block_N : (i + 1) * block_N])

    return main


def ref_blocksparse_copy(A, B, BlockMask, M, N, N_S, block_M, block_N):
    """Reference implementation for blocksparse copy.

    For every (row block, column block) tile of the output, reads the source
    row offset from ``BlockMask`` and copies the matching ``block_M x block_N``
    slice of ``A`` (always taken from columns ``0:block_N``) into a clone of
    ``B``, which is returned.
    """
    expected = B.clone()

    for row_blk in range(M // block_M):
        r0 = row_blk * block_M
        for col_blk in range(N_S):
            src = BlockMask[row_blk, col_blk].item()
            c0 = col_blk * block_N
            expected[r0:r0 + block_M, c0:c0 + block_N] = A[src:src + block_M, :block_N]

    return expected


def run_blocksparse_copy(M, N, block_M, block_N, pass_configs=None):
    """Compile the blocksparse copy kernel, run it on CUDA, and check the
    result against the pure-torch reference.

    ``pass_configs`` is forwarded to ``tilelang.compile`` (``None`` means no
    overrides); ``out_idx=[1]`` makes the kernel allocate and return ``B``.
    """
    N_S = N // block_N

    kernel = tilelang.compile(
        blocksparse_copy_kernel(M, N, N_S, block_M, block_N),
        out_idx=[1],
        target="cuda",
        pass_configs=pass_configs if pass_configs is not None else {},
    )

    # Input / (unused-by-kernel) baseline output tensors.
    a = torch.randn(M, N, device="cuda", dtype=torch.float16)
    b = torch.zeros(M, N, device="cuda", dtype=torch.float16)

    # Fill BlockMask with valid, block_M-aligned source row offsets.
    num_row_blocks = M // block_M
    block_mask = torch.zeros((num_row_blocks, N_S), dtype=torch.int32, device="cuda")
    max_row_block = (M - block_M) // block_M  # loop-invariant bound
    for by in range(num_row_blocks):
        for i in range(N_S):
            block_mask[by, i] = torch.randint(0, max_row_block + 1, (1,)).item() * block_M

    # Kernel output vs. reference.
    result = kernel(a, block_mask)
    expected = ref_blocksparse_copy(a, b, block_mask, M, N, N_S, block_M, block_N)

    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_cuda
def test_blocksparse_copy_tma():
    """Blocksparse copy via the default TMA (Tensor Memory Accelerator) path."""
    run_blocksparse_copy(
        M=1024,
        N=1024,
        block_M=128,
        block_N=128,
        pass_configs={},
    )


@tilelang.testing.requires_cuda
def test_blocksparse_copy_cp_async():
    """Blocksparse copy via CP.ASYNC, with TMA and warp specialization disabled."""
    configs = {
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    }
    run_blocksparse_copy(M=1024, N=1024, block_M=128, block_N=128, pass_configs=configs)


# Allow running this test module directly (tilelang's test entry point).
if __name__ == "__main__":
    tilelang.testing.main()