test_multithreads.py 2.79 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import traceback
import threading
import multiprocessing
import numpy as np
from numba import cuda
from numba.cuda.testing import (skip_on_cudasim, skip_under_cuda_memcheck,
                                CUDATestCase)
import unittest

# Feature flags used by the skipIf decorators on the test class below.
# ``concurrent.futures`` supplies the thread pool that
# check_concurrent_compiling() relies on.
try:
    from concurrent.futures import ThreadPoolExecutor
    has_concurrent_futures = True
except ImportError:
    has_concurrent_futures = False


# ``multiprocessing.get_context`` is needed to request the 'spawn' start method.
has_mp_get_context = hasattr(multiprocessing, 'get_context')


def check_concurrent_compiling():
    """Launch a freshly defined CUDA kernel from several threads at once.

    The kernel is defined inside this function with no explicit signature,
    so presumably it is compiled lazily on first launch and the pool workers
    race to compile it (confirm against numba's ``cuda.jit`` semantics).
    Each of the ten device arrays starts as ``0..9``. After one
    single-thread launch, only element 0 is expected to be incremented.

    Raises:
        AssertionError: if any array returned by a worker differs from the
            expected result.
    """
    @cuda.jit
    def bump_first(x):
        x[0] += 1

    def launch(dev_ary):
        # One block of one thread: element 0 is incremented exactly once.
        bump_first[1, 1](dev_ary)
        return dev_ary

    expected = np.arange(10)
    expected[0] += 1
    device_arrays = [cuda.to_device(np.arange(10)) for _ in range(10)]
    with ThreadPoolExecutor(max_workers=4) as pool:
        results = pool.map(launch, device_arrays)
        for result in results:
            np.testing.assert_equal(result, expected)


def spawn_process_entry(q):
    """Child-process entry point: run the concurrent-compilation check.

    Exactly one item is always put on *q*:
    ``None`` on success, otherwise the formatted traceback preceded by a
    separator line. The parent blocks on ``q.get()``, so every failure must
    be reported. This is why the handler deliberately catches
    ``BaseException`` rather than only ``Exception``.

    Args:
        q: a multiprocessing queue shared with the parent process.
    """
    report = None
    try:
        check_concurrent_compiling()
    except BaseException:
        report = '\n'.join(['', '=' * 80, traceback.format_exc()])
    q.put(report)


@skip_under_cuda_memcheck('Hangs cuda-memcheck')
@skip_on_cudasim('disabled for cudasim')
class TestMultiThreadCompiling(CUDATestCase):
    """Exercise CUDA compilation and memory transfers off the main thread.

    Covers concurrent kernel compilation from a thread pool, compilation in
    a freshly spawned process, and device-array copies issued from a worker
    thread.
    """

    def _run_in_thread(self, target, *args):
        """Call ``target(*args)`` on a new thread and wait for it to finish.

        ``threading.Thread`` does not propagate exceptions to the thread that
        joins it. An error raised by *target* would only be printed, and the
        test would fail later, if at all, on an unrelated-looking value
        mismatch. The exception is captured instead and re-raised here, so
        the real error becomes the test failure.

        Raises:
            Exception: whatever *target* raised, re-raised in the caller.
        """
        errors = []

        def runner():
            try:
                target(*args)
            except Exception as e:  # re-raised in the calling thread below
                errors.append(e)

        th = threading.Thread(target=runner)
        th.start()
        th.join()
        if errors:
            raise errors[0]

    @unittest.skipIf(not has_concurrent_futures, "no concurrent.futures")
    def test_concurrent_compiling(self):
        check_concurrent_compiling()

    @unittest.skipIf(not has_mp_get_context, "no multiprocessing.get_context")
    def test_spawn_concurrent_compilation(self):
        # force CUDA context init
        cuda.get_current_device()
        # use "spawn" to avoid inheriting the CUDA context
        ctx = multiprocessing.get_context('spawn')

        q = ctx.Queue()
        p = ctx.Process(target=spawn_process_entry, args=(q,))
        p.start()
        try:
            # The child always reports one item: None on success, or a
            # formatted traceback. NOTE(review): if the child dies without
            # reporting (e.g. a hard crash), this get() waits indefinitely.
            err = q.get()
        finally:
            p.join()
        if err is not None:
            raise AssertionError(err)
        self.assertEqual(p.exitcode, 0, 'test failed in child process')

    # Judging by their names, the two tests below are regressions for an
    # invalid-CUDA-context error when device arrays created on one thread
    # are used from another; confirm against the originating issue.
    def test_invalid_context_error_with_d2h(self):
        # Device-to-host copy of a main-thread array, issued from a worker.
        def d2h(arr, out):
            out[:] = arr.copy_to_host()

        arr = np.arange(1, 4)
        out = np.zeros_like(arr)
        darr = cuda.to_device(arr)
        self._run_in_thread(d2h, darr, out)
        np.testing.assert_equal(arr, out)

    def test_invalid_context_error_with_d2d(self):
        # Device-to-device copy between main-thread arrays, issued from a
        # worker.
        def d2d(dst, src):
            dst.copy_to_device(src)

        arr = np.arange(100)
        common = cuda.to_device(arr)
        darr = cuda.to_device(np.zeros(common.shape, dtype=common.dtype))
        self._run_in_thread(d2d, darr, common)
        np.testing.assert_equal(darr.copy_to_host(), arr)


if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()