test_freevar.py 745 Bytes
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np

from numba import cuda
from numba.cuda.testing import unittest, CUDATestCase


class TestFreeVar(CUDATestCase):
    def test_freevar(self):
        """Compile and launch a kernel whose ``cuda.shared.array`` call takes
        its shape and dtype from closure (free) variables rather than from
        literals written inline.

        The test passes as long as compilation and launch succeed; the value
        copied out of shared memory is never inspected, since the shared
        buffer is never written before it is read.
        """
        from numba import float32

        # Both names are captured by the kernel below as free variables;
        # resolving them during compilation is what this test exercises.
        shared_len = 1024
        elem_type = float32

        @cuda.jit("(float32[::1], intp)")
        def kernel(out, idx):
            "Copy one element of an unwritten shared buffer into out."
            buf = cuda.shared.array(shared_len, dtype=elem_type)
            out[idx] = buf[idx]

        host_arr = np.arange(2, dtype="float32")
        blocks_per_grid, threads_per_block = 1, 1
        kernel[blocks_per_grid, threads_per_block](host_arr, 0)


if __name__ == '__main__':
    # Executing this file directly hands control to unittest's command-line
    # runner, which collects and runs the TestCase classes in this module.
    unittest.main()