test_iterators.py 2.29 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from numba import cuda
from numba.cuda.testing import unittest, CUDATestCase

import numpy as np


class TestIterators(CUDATestCase):
    """Exercise ``enumerate`` and ``zip`` (and their nestings) in CUDA kernels.

    Every kernel below walks its input array(s) with the iterator under test
    and stores a nonzero code in ``error[0]`` when a yielded index/value
    disagrees with direct indexing, or when the number of iterations is not
    ``len(x)``. The host launches a single thread and asserts the code is 0.
    """

    def _test_onearg_function(self, kernel):
        """Launch *kernel* on one thread with one input array and check it.

        *kernel* must accept ``(x, error)`` and leave ``error[0] == 0`` on
        success.
        """
        data = np.array([10, 9, 8, 7, 6])
        err = np.zeros(1, np.int32)

        kernel[1, 1](data, err)
        self.assertEqual(err[0], 0)

    def _test_twoarg_function(self, f):
        """Launch *f* on one thread with two equal-length input arrays.

        *f* must accept ``(x, y, error)`` and leave ``error[0] == 0`` on
        success.
        """
        first = np.array([10, 9, 8, 7, 6])
        second = np.array([1, 2, 3, 4, 5])
        err = np.zeros(1, np.int32)

        f[1, 1](first, second, err)
        self.assertEqual(err[0], 0)

    def test_enumerate(self):
        """``enumerate(x)`` yields (index, element) pairs in order."""
        @cuda.jit
        def enumerator(x, error):
            seen = 0

            for idx, val in enumerate(x):
                # The yielded index must match our own running counter...
                if idx != seen:
                    error[0] = 1
                # ...and the yielded value must match direct indexing.
                if val != x[idx]:
                    error[0] = 2

                seen += 1

            # Each element must have been visited exactly once.
            if seen != len(x):
                error[0] = 3

        self._test_onearg_function(enumerator)

    def test_zip(self):
        """``zip(x, y)`` yields element pairs in lockstep."""
        @cuda.jit
        def zipper(x, y, error):
            pos = 0

            for x_val, y_val in zip(x, y):
                if x_val != x[pos]:
                    error[0] = 1
                if y_val != y[pos]:
                    error[0] = 2

                pos += 1

            # Inputs have equal length, so zip must visit all of x.
            if pos != len(x):
                error[0] = 3

        self._test_twoarg_function(zipper)

    def test_enumerate_zip(self):
        """``enumerate(zip(x, y))`` yields (index, (x_i, y_i)) tuples."""
        @cuda.jit
        def enumerator_zipper(x, y, error):
            seen = 0

            for idx, (x_val, y_val) in enumerate(zip(x, y)):
                if idx != seen:
                    error[0] = 1
                if x_val != x[idx]:
                    error[0] = 2
                if y_val != y[idx]:
                    error[0] = 3

                seen += 1

            if seen != len(x):
                error[0] = 4

        self._test_twoarg_function(enumerator_zipper)

    def test_zip_enumerate(self):
        """``zip(enumerate(x), y)`` yields ((index, x_i), y_i) tuples."""
        @cuda.jit
        def zipper_enumerator(x, y, error):
            seen = 0

            for (idx, x_val), y_val in zip(enumerate(x), y):
                if idx != seen:
                    error[0] = 1
                if x_val != x[idx]:
                    error[0] = 2
                if y_val != y[idx]:
                    error[0] = 3

                seen += 1

            if seen != len(x):
                error[0] = 4

        self._test_twoarg_function(zipper_enumerator)


if __name__ == '__main__':
    # Allow running this test module directly as a script; unittest.main()
    # discovers and runs the TestCase classes defined above.
    unittest.main()