import torch
import tilelang
import tilelang.testing
from tilelang import language as T

# Pass options handed to the JIT for this kernel: both the
# warp-specialization and the TMA-lowering "disable" switches are turned on.
_PASS_CONFIGS = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}


@tilelang.jit(pass_configs=_PASS_CONFIGS)
def _cumsum_view_infer_layout(hidden):
    """Build a kernel that runs ``T.cumsum`` on a 2-D view of a 1-D buffer.

    Args:
        hidden: Size of the second dimension of the input tensor; it is also
            the length of the shared buffer allocated inside the kernel.

    Returns:
        The ``T.prim_func`` defined below. Its single argument is a
        ``(num_tokens, hidden)`` ``'float'`` tensor whose first dimension is
        declared dynamic.
    """
    num_tokens = T.dynamic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        # The kernel extent is num_tokens and ``pid`` selects the row of
        # ``x`` that gets staged.
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype='float')
            T.copy(x[pid, :], smem)
            # The scan is applied to a (1, hidden) view of the 1-D buffer,
            # along dim 1, with no separate destination passed.
            # NOTE(review): judging by the function names this view + cumsum
            # combination is the case under test - confirm against the
            # originating issue.
            T.cumsum(T.view(smem, (1, hidden)), dim=1)

    return buggy_kernel


def test_cumsum_view_infer_layout():
    """Build the kernel for ``hidden == 128`` and call it on a random tensor.

    Nothing is asserted about the result; the test passes when building and
    calling the kernel raises no exception.
    """
    hidden = 128
    sample = torch.randn((1, hidden), dtype=torch.float, device='cuda')
    compiled = _cumsum_view_infer_layout(hidden)
    compiled(sample)


if __name__ == '__main__':
    tilelang.testing.main()