# test_ptds.py (4.83 KB) -- last committed by dugupeiwen
import multiprocessing as mp
import logging
import traceback
from numba.cuda.testing import unittest, CUDATestCase
from numba.cuda.testing import (skip_on_cudasim, skip_with_cuda_python,
                                skip_under_cuda_memcheck)
from numba.tests.support import linux_only


def child_test():
    """Launch kernels from several threads with PTDS enabled and return the
    driver log.

    This function is intended to run in a freshly-spawned process so that the
    per-thread default stream (PTDS) setting is applied before any CUDA driver
    call is made.

    Returns
    -------
    str
        Everything logged by the ``numba.cuda.cudadrv.driver`` logger at
        DEBUG level while the test ran.

    Raises
    ------
    AssertionError
        If any thread's result array differs from ``x * N_ADDITIONS``.
    """
    from numba import cuda, int32, void
    from numba.core import config
    import io
    import numpy as np
    import threading

    # Enable PTDS before we make any CUDA driver calls.  Enabling it first
    # ensures that PTDS APIs are used because the CUDA driver looks up API
    # functions on first use and memoizes them.
    config.CUDA_PER_THREAD_DEFAULT_STREAM = 1

    # Set up log capture for the Driver API so we can see what API calls were
    # used.
    logbuf = io.StringIO()
    handler = logging.StreamHandler(logbuf)
    cudadrv_logger = logging.getLogger('numba.cuda.cudadrv.driver')
    cudadrv_logger.addHandler(handler)
    cudadrv_logger.setLevel(logging.DEBUG)

    # Set up data for our test, and copy over to the device
    N = 2 ** 16
    N_THREADS = 10
    N_ADDITIONS = 4096

    # Seed the RNG for repeatability
    np.random.seed(1)
    x = np.random.randint(low=0, high=1000, size=N, dtype=np.int32)
    r = np.zeros_like(x)

    # One input and output array for each thread
    xs = [cuda.to_device(x) for _ in range(N_THREADS)]
    rs = [cuda.to_device(r) for _ in range(N_THREADS)]

    # Compute the grid size and get the [per-thread] default stream.  Round
    # the block count up so every element is covered even if N were not a
    # multiple of the block size; the kernel's bounds guard discards the
    # surplus threads.
    n_threads = 256
    n_blocks = (N + n_threads - 1) // n_threads
    stream = cuda.default_stream()

    # A simple multiplication-by-addition kernel. What it does exactly is not
    # too important; only that we have a kernel that does something.
    @cuda.jit(void(int32[::1], int32[::1]))
    def f(r, x):
        i = cuda.grid(1)

        # Valid indices are 0 .. len(r) - 1, so a thread whose index equals
        # len(r) must also bail out to avoid an out-of-bounds access.
        if i >= len(r):
            return

        # Accumulate x into r
        for j in range(N_ADDITIONS):
            r[i] += x[i]

    # This function will be used to launch the kernel from each thread on its
    # own unique data.
    def kernel_thread(n):
        f[n_blocks, n_threads, stream](rs[n], xs[n])

    # Create threads
    threads = [threading.Thread(target=kernel_thread, args=(i,))
               for i in range(N_THREADS)]

    # Start all threads
    for thread in threads:
        thread.start()

    # Wait for all threads to finish, to ensure that we don't synchronize with
    # the device until all kernels are scheduled.
    for thread in threads:
        thread.join()

    # Synchronize with the device
    cuda.synchronize()

    # Check output is as expected
    expected = x * N_ADDITIONS
    for i in range(N_THREADS):
        np.testing.assert_equal(rs[i].copy_to_host(), expected)

    # Return the driver log output to the calling process for checking
    handler.flush()
    return logbuf.getvalue()


def child_test_wrapper(result_queue):
    """Run ``child_test`` and report the outcome on ``result_queue``.

    Exactly one ``(success, output)`` tuple is put on the queue.  On success
    ``output`` is the value returned by ``child_test``; on failure it is the
    formatted traceback of whatever was raised.
    """
    try:
        outcome = (True, child_test())
    except BaseException:
        # Deliberately catch everything (same reach as a bare ``except``) so
        # that any failure in the child is propagated to the parent process
        # as text rather than being lost.
        outcome = (False, traceback.format_exc())

    result_queue.put(outcome)


# Run on Linux only until the reason for test hangs on Windows (Issue #8635,
# https://github.com/numba/numba/issues/8635) is diagnosed
@linux_only
@skip_under_cuda_memcheck('Hangs cuda-memcheck')
@skip_on_cudasim('Streams not supported on the simulator')
class TestPTDS(CUDATestCase):
    """Check that enabling PTDS makes the driver use the per-thread default
    stream variants of its API functions."""

    @skip_with_cuda_python('Function names unchanged for PTDS with NV Binding')
    def test_ptds(self):
        """Run ``child_test`` in a spawned process and inspect the driver log
        it returns for PTDS and non-PTDS function names."""
        # Run a test with PTDS enabled in a child process
        ctx = mp.get_context('spawn')
        result_queue = ctx.Queue()
        proc = ctx.Process(target=child_test_wrapper, args=(result_queue,))
        proc.start()

        # Fetch the result *before* joining.  A process that has put data on
        # a multiprocessing queue does not exit until that data has been
        # flushed to the underlying pipe, so joining first can deadlock when
        # the captured driver log is larger than the pipe buffer.
        success, output = result_queue.get()
        proc.join()

        # Ensure the child process ran to completion before checking its output
        if not success:
            self.fail(output)

        # Functions with a per-thread default stream variant that we expect to
        # see in the output
        ptds_functions = ('cuMemcpyHtoD_v2_ptds', 'cuLaunchKernel_ptsz',
                          'cuMemcpyDtoH_v2_ptds')

        for fn in ptds_functions:
            with self.subTest(fn=fn, expected=True):
                self.assertIn(fn, output)

        # Non-PTDS versions of the functions that we should not see in the
        # output:
        legacy_functions = ('cuMemcpyHtoD_v2', 'cuLaunchKernel',
                            'cuMemcpyDtoH_v2')

        for fn in legacy_functions:
            with self.subTest(fn=fn, expected=False):
                # Ensure we only spot these function names appearing without a
                # _ptds or _ptsz suffix by checking including the end of the
                # line in the log
                fn_at_end = f'{fn}\n'
                self.assertNotIn(fn_at_end, output)


# Allow this test module to be executed directly as a script.
if __name__ == '__main__':
    unittest.main()