# test_tilelang_issue_1001.py
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        # Warp specialization and TMA lowering are switched off for this kernel;
        # presumably to keep the lowered code minimal for the reproducer — confirm.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def _cumsum_view_infer_layout(hidden):
    """Build a kernel that runs ``T.cumsum`` on a reshaped view of a shared buffer.

    Judging by the file name, this is a reproducer for tilelang issue 1001
    (layout inference for ``T.cumsum`` applied to a ``T.view``); confirm
    against the issue tracker.

    The kernel launches one program instance (128 threads) per row of ``x``.
    Each instance copies its row into a shared buffer of length ``hidden``.
    It then runs ``T.cumsum`` along dim 1 of a ``(1, hidden)`` view of that
    buffer. The result is never copied back to ``x``, so a caller can only
    observe whether the kernel compiles and launches without error.

    Args:
        hidden: Static size of the second dimension of ``x``. The first
            dimension (``num_tokens``) is symbolic and resolved at call time.

    Returns:
        ``buggy_kernel``. ``tilelang.jit`` turns it into a callable kernel,
        which takes a ``(num_tokens, hidden)`` float tensor.
    """
    # Number of rows is left dynamic so one compiled kernel serves any batch size.
    num_tokens = T.dynamic("num_tokens")

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]):
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype="float")

            T.copy(x[pid, :], smem)
            # Cumsum is applied to a 2-D view of the 1-D shared buffer; this
            # view is the construct the test name says exercises layout inference.
            T.cumsum(T.view(smem, (1, hidden)), dim=1)

    return buggy_kernel


def test_cumsum_view_infer_layout():
    """Compile and launch the cumsum-on-view kernel on one row of random data.

    The test passes as long as JIT compilation and the kernel launch do not
    raise. The kernel writes nothing back to ``x``, so there is no output to
    compare. A CUDA device is required because ``x`` is allocated on
    ``"cuda"``.
    """
    hidden = 128
    x = torch.randn(1, hidden, device="cuda", dtype=torch.float)

    kernel = _cumsum_view_infer_layout(hidden)
    kernel(x)


# Allow running this test file directly instead of through the test runner.
if __name__ == "__main__":
    tilelang.testing.main()