# test_example_mla_decode.py
# Smoke test that runs the MLA decode example's main() entry point.
import tilelang.testing

import example_mla_decode
from unittest import mock
import sys


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mla_decode():
    """Run the MLA decode example end to end as a smoke test.

    The test is gated by tilelang's CUDA requirement decorators. Per the
    decorator arguments, it needs a device with compute capability >= 9.0.

    ``example_mla_decode.main()`` is called with ``sys.argv`` temporarily
    replaced by just the script name. Any command-line parsing inside
    ``main()`` therefore sees no extra arguments instead of pytest's own argv.
    Presumably the example then falls back to its default configuration;
    TODO confirm against example_mla_decode.
    """
    # main() takes no parameters here, so argv is the only channel for
    # configuration. Patching it keeps pytest's flags from leaking into the
    # example, and mock.patch.object restores the original argv on exit.
    with mock.patch.object(sys, 'argv', ["example_mla_decode.py"]):
        example_mla_decode.main()


if __name__ == "__main__":
    # Allow running this file directly (python test_example_mla_decode.py).
    # This defers to tilelang's test entry point, which presumably discovers
    # and runs the tests in this module; confirm in tilelang.testing.
    tilelang.testing.main()