test_forall.py 1.42 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np

from numba import cuda
import unittest
from numba.cuda.testing import CUDATestCase


@cuda.jit
def foo(x):
    """Add one to every element of ``x`` in place.

    Each thread works on the single element at its absolute 1D grid index.
    Threads whose index falls at or beyond ``x.size`` return without touching
    memory. The launch may contain more threads than elements, so this guard
    is required.
    """
    idx = cuda.grid(1)
    if idx >= x.size:
        # Surplus thread past the end of the array: nothing to do.
        return
    x[idx] += 1


class TestForAll(CUDATestCase):
    """Tests for launching kernels through ``Dispatcher.forall``.

    ``kernel.forall(n)`` returns a launcher intended to cover ``n`` tasks.
    The kernels used here bounds-check their own grid index against the
    array size, so surplus threads are harmless.
    """

    def test_forall_1(self):
        # Increment every element of an integer array via the module-level
        # ``foo`` kernel and check each element went up by exactly one.
        arr = np.arange(11)
        orig = arr.copy()
        foo.forall(arr.size)(arr)
        # Integer arithmetic is exact, so compare exactly rather than with
        # a decimal tolerance.
        np.testing.assert_array_equal(arr, orig + 1)

    def test_forall_2(self):
        # SAXPY-style kernel with an explicit float32 signature:
        # y[i] = a * x[i] + y[i].
        @cuda.jit("void(float32, float32[:], float32[:])")
        def bar(a, x, y):
            i = cuda.grid(1)
            if i < x.size:
                y[i] = a * x[i] + y[i]

        x = np.arange(13, dtype=np.float32)
        y = np.arange(13, dtype=np.float32)
        oldy = y.copy()
        a = 1.234
        bar.forall(y.size)(a, x, y)
        # ``a`` is narrowed to float32 by the kernel signature, so compare
        # with a tolerance rather than exactly.
        np.testing.assert_array_almost_equal(y, a * x + oldy, decimal=3)

    def test_forall_no_work(self):
        # Ensure that forall doesn't launch a kernel with no blocks when called
        # with 0 elements. See Issue #5017.
        arr = np.arange(11)
        orig = arr.copy()
        foo.forall(0)(arr)
        # ``foo`` bounds-checks against ``arr.size``, not the task count.
        # Any launched thread would therefore have modified ``arr``, so an
        # unchanged array shows that no work was done.
        np.testing.assert_array_equal(arr, orig)

    def test_forall_negative_work(self):
        # Ensure that forall doesn't allow the creation of a forall with a
        # negative element count.
        with self.assertRaises(ValueError) as raises:
            foo.forall(-1)
        self.assertIn("Can't create ForAll with negative task count",
                      str(raises.exception))


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()