test_driver.py 1.73 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pickle
import threading
import unittest

import cupy
from cupy import testing
from cupy.cuda import driver


@unittest.skipIf(cupy.cuda.runtime.is_hip, 'Context API is deprecated in HIP')
class TestDriver(unittest.TestCase):
    """Tests for the context-related wrappers in ``cupy.cuda.driver``."""

    def test_ctxGetCurrent(self):
        # Make sure to create context.
        cupy.arange(1)
        assert 0 != driver.ctxGetCurrent()

    def test_ctxGetCurrent_thread(self):
        # Make sure to create context in main thread.
        cupy.arange(1)

        # Filled in by the worker thread; read by the main thread after join().
        results = {}
        errors = []

        def f():
            try:
                results['before'] = driver.ctxGetCurrent()
                cupy.cuda.Device().use()
                cupy.arange(1)
                results['after'] = driver.ctxGetCurrent()
            except Exception as e:
                # Thread.join() does not propagate exceptions from the
                # worker; stash it so the main thread can re-raise it.
                errors.append(e)

        t = threading.Thread(target=f)
        t.daemon = True
        t.start()
        t.join()

        # Re-raise any worker failure. Otherwise a missing result (None)
        # could make the `0 != ...` assertion below pass spuriously.
        if errors:
            raise errors[0]

        # The returned context pointer must be NULL on sub thread
        # without valid context.
        assert 0 == results['before']

        # After the context is created, it should return the valid
        # context pointer.
        assert 0 != results['after']

    @testing.multi_gpu(2)
    def test_ctxGetDevice(self):
        # ctxGetDevice() should report the ordinal of the device whose
        # context is current inside each `with Device(i)` scope.
        with cupy.cuda.Device(1):
            dev = driver.ctxGetDevice()
            assert dev == 1
        with cupy.cuda.Device(0):
            dev = driver.ctxGetDevice()
            assert dev == 0

    def test_streamGetCtx(self):
        # A newly created stream is expected to belong to the current
        # context.
        s = cupy.cuda.Stream()
        ctx = driver.streamGetCtx(s.ptr)
        ctx2 = driver.ctxGetCurrent()
        assert ctx == ctx2


class TestExceptionPicklable(unittest.TestCase):
    """Check that ``driver.CUDADriverError`` survives a pickle round trip."""

    def test(self):
        original = driver.CUDADriverError(1)
        payload = pickle.dumps(original)
        restored = pickle.loads(payload)
        # A faithful round trip preserves both the constructor arguments
        # and the rendered error message.
        assert original.args == restored.args
        assert str(original) == str(restored)