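"""Tests for the CUDA context stack: device selection through cuda.gpus, the
context API (memory info queries and context-switching rules), and
interoperability with CUDA contexts created by third-party code.
"""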
import numbers
from ctypes import byref
import weakref

from numba import cuda
from numba.cuda.testing import unittest, CUDATestCase, skip_on_cudasim
from numba.cuda.cudadrv import driver


class TestContextStack(CUDATestCase):
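    """Tests for the cuda.gpus device list and the current-context stack."""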
    def setUp(self):
        super().setUp()
        # Reset before testing
        cuda.close()

    def test_gpus_current(self):
        self.assertIsNone(cuda.gpus.current)
        with cuda.gpus[0]:
            self.assertEqual(int(cuda.gpus.current.id), 0)

    def test_gpus_len(self):
        self.assertGreater(len(cuda.gpus), 0)

    def test_gpus_iter(self):
        gpulist = list(cuda.gpus)
        self.assertGreater(len(gpulist), 0)


class TestContextAPI(CUDATestCase):
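    """Tests for the context API: memory info queries and the rules governing
    switching contexts between devices."""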

    def tearDown(self):
        super().tearDown()
        cuda.close()

    def test_context_memory(self):
        try:
            mem = cuda.current_context().get_memory_info()
        except NotImplementedError:
            self.skipTest('EMM Plugin does not implement get_memory_info()')

        self.assertIsInstance(mem.free, numbers.Number)
        self.assertEqual(mem.free, mem[0])

        self.assertIsInstance(mem.total, numbers.Number)
        self.assertEqual(mem.total, mem[1])

        self.assertLessEqual(mem.free, mem.total)

    @unittest.skipIf(len(cuda.gpus) < 2, "need more than 1 GPU")
    @skip_on_cudasim('CUDA HW required')
    def test_forbidden_context_switch(self):
        # Cannot switch context inside a `cuda.require_context`
        @cuda.require_context
        def switch_gpu():
            with cuda.gpus[1]:
                pass

        with cuda.gpus[0]:
            with self.assertRaises(RuntimeError) as raises:
                switch_gpu()

            self.assertIn("Cannot switch CUDA-context.", str(raises.exception))

    @unittest.skipIf(len(cuda.gpus) < 2, "need more than 1 GPU")
    def test_accepted_context_switch(self):
        def switch_gpu():
            with cuda.gpus[1]:
                return cuda.current_context().device.id

        with cuda.gpus[0]:
            devid = switch_gpu()
        self.assertEqual(int(devid), 1)


@skip_on_cudasim('CUDA HW required')
class Test3rdPartyContext(CUDATestCase):
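    """Tests for interoperability with CUDA contexts that were created outside
    of Numba, e.g. by another library driving the CUDA driver API directly."""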
    def tearDown(self):
        super().tearDown()
        cuda.close()

    def test_attached_primary(self, extra_work=lambda: None):
        # Emulate primary context creation by 3rd party
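        # cuDevicePrimaryCtxRetain creates the device's primary context if it
        # does not already exist and retains a reference to it, which is how a
        # third-party library that initialises CUDA itself would typically
        # leave the device. The call signature differs between the NVIDIA
        # cuda-python binding and the ctypes binding (driver.USE_NV_BINDING).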
        the_driver = driver.driver
        if driver.USE_NV_BINDING:
            dev = driver.binding.CUdevice(0)
            hctx = the_driver.cuDevicePrimaryCtxRetain(dev)
        else:
            dev = 0
            hctx = driver.drvapi.cu_context()
            the_driver.cuDevicePrimaryCtxRetain(byref(hctx), dev)
        try:
            ctx = driver.Context(weakref.proxy(self), hctx)
            ctx.push()
            # Check that the context from numba matches the created primary
            # context.
            my_ctx = cuda.current_context()
            if driver.USE_NV_BINDING:
                self.assertEqual(int(my_ctx.handle), int(ctx.handle))
            else:
                self.assertEqual(my_ctx.handle.value, ctx.handle.value)

            extra_work()
        finally:
            ctx.pop()
            the_driver.cuDevicePrimaryCtxRelease(dev)

    def test_attached_non_primary(self):
        # Emulate non-primary context creation by 3rd party
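        # cuCtxCreate makes a brand-new (non-primary) context and pushes it
        # onto the current thread's context stack. Numba only operates on the
        # primary context, so asking for the current context below is expected
        # to raise a RuntimeError.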
        the_driver = driver.driver
        if driver.USE_NV_BINDING:
            flags = 0
            dev = driver.binding.CUdevice(0)
            hctx = the_driver.cuCtxCreate(flags, dev)
        else:
            hctx = driver.drvapi.cu_context()
            the_driver.cuCtxCreate(byref(hctx), 0, 0)
        try:
            cuda.current_context()
        except RuntimeError as e:
            # Expecting an error about non-primary CUDA context
            self.assertIn("Numba cannot operate on non-primary CUDA context ",
                          str(e))
        else:
            self.fail("No RuntimeError raised")
        finally:
            the_driver.cuCtxDestroy(hctx)

    def test_cudajit_in_attached_primary_context(self):
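        """Check that a @cuda.jit kernel can be compiled and launched when the
        primary context was attached by third-party code rather than created
        by Numba (reuses test_attached_primary via its extra_work hook)."""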
        def do():
            from numba import cuda

            @cuda.jit
            def foo(a):
                for i in range(a.size):
                    a[i] = i

            a = cuda.device_array(10)
            foo[1, 1](a)
            self.assertEqual(list(a.copy_to_host()), list(range(10)))

        self.test_attached_primary(do)


if __name__ == '__main__':
    unittest.main()