"examples/vscode:/vscode.git/clone" did not exist on "1e69bf538774894eadfe21b76f047a515bdde322"
test_math_bitwise_reduce.py 3.95 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import tilelang
import tilelang.language as T
import torch
import tilelang.testing


@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
    },
)
def bitwise_reduce(
    M,
    N,
    block_M,
    block_N,
    name,
    func,
    clear=True,
):

    @T.prim_func
    def reduce_func(
            A: T.Tensor((M, N), "int32"),
            B: T.Tensor((M), "int32"),
            Output: T.Tensor((M), "int32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), "int32")
            A_fragment = T.alloc_fragment((block_M, block_N), "int32")
            B_shared = T.alloc_shared((block_M,), "int32")
            B_fragment = T.alloc_fragment((block_M), "int32")
            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(A_shared, A_fragment)
            T.copy(B[by * block_M], B_shared)
            T.copy(B_shared, B_fragment)
            func(A_fragment, B_fragment, clear=clear)
            T.copy(B_fragment, Output[by * block_M])

    return reduce_func


def run_single_bitwise_reduce(
    name,
    func,
    clear=True,
):
    M, N = 32, 32
    block_M, block_N = 32, 32
    kernel = bitwise_reduce(M, N, block_M, block_N, name, func, clear)

    # Generate test data that exercises all bit patterns for robust bitwise reduce testing
    a = torch.zeros((M, N), device="cuda", dtype=torch.int32)

    # Fill with patterns that will produce meaningful results for bitwise operations:
    # - Different bit patterns across rows/columns
    # - Mix of 0s and 1s in various positions
    # - Some all-1s and all-0s patterns for edge cases
    for i in range(M):
        for j in range(N):
            # Create varied bit patterns:
            # Row-based pattern: alternating bits based on row index
            row_pattern = (i & 0xF) << (i % 4)  # 4-bit patterns shifted by row

            # Column-based pattern: different bit positions set based on column
            col_pattern = (1 << (j % 31))  # Single bit set at different positions

            # Combine patterns with XOR to create diverse bit distributions
            # Add some deterministic "noise" based on position
            position_factor = (i * N + j) % 256

            # Final value combines all patterns
            a[i, j] = (row_pattern ^ col_pattern ^ position_factor) & 0xFFFFFFFF

            if i % 4 == 0:
                a[i, j] &= ~(0x1 << (i // 4))
            elif i % 2 == 0:
                a[i, j] |= (0x1 << (i // 2))

    if name == "reduce_bitand":
        expected = torch.full((M,), -1, device="cuda", dtype=torch.int32)
    elif name == "reduce_bitor" or name == "reduce_bitxor":
        expected = torch.full((M,), 0, device="cuda", dtype=torch.int32)
    else:
        raise ValueError("Invalid name: {}".format(name))

    output = kernel(a, expected)

    for i in range(M):
        for j in range(N):
            if name == "reduce_bitand":
                expected[i] = expected[i] & a[i, j]
            elif name == "reduce_bitor":
                expected[i] = expected[i] | a[i, j]
            elif name == "reduce_bitxor":
                expected[i] = expected[i] ^ a[i, j]
            else:
                raise ValueError("Invalid name: {}".format(name))
    assert torch.all(output == expected)
    print("✓ {} with clear={} test passed".format(name, clear))


@tilelang.testing.requires_cuda
def test_bitwise_reduce_ops():
    run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=True)
    run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=True)
    run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=True)
    run_single_bitwise_reduce("reduce_bitand", T.reduce_bitand, clear=False)
    run_single_bitwise_reduce("reduce_bitor", T.reduce_bitor, clear=False)
    run_single_bitwise_reduce("reduce_bitxor", T.reduce_bitxor, clear=False)


if __name__ == "__main__":
    tilelang.testing.main()