test_linear_attn.py 354 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tilelang.testing

import example_linear_attn_fwd
import example_linear_attn_bwd


@tilelang.testing.requires_cuda
def test_example_linear_attn_fwd():
    example_linear_attn_fwd.main()


@tilelang.testing.requires_cuda
def test_example_linear_attn_bwd():
    example_linear_attn_bwd.main()


if __name__ == "__main__":
    tilelang.testing.main()