test_init.py 4.06 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import operator
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from unittest import mock

import numpy
import pytest

import cupy
import cupyx


def _run_script(code):
    # subprocess is required not to interfere with cupy module imported in top
    # of this file
    temp_dir = tempfile.mkdtemp()
    try:
        script_path = os.path.join(temp_dir, 'script.py')
        with open(script_path, 'w') as f:
            f.write(code)
        proc = subprocess.Popen(
            [sys.executable, script_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE)
        stdoutdata, stderrdata = proc.communicate()
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
    return proc.returncode, stdoutdata, stderrdata


def _test_cupy_available(self):
    """Return the value of ``cupy.is_available()`` in a fresh interpreter.

    ``self`` is the calling test case; this helper does not use it.

    The child process must exit successfully and print exactly ``True`` or
    ``False`` (with either a POSIX or a Windows line ending); anything else
    fails the calling test.
    """
    code, out, err = _run_script('''
import cupy
print(cupy.is_available())''')
    assert code == 0, 'stderr: {!r}'.format(err)
    true_outputs = (b'True\n', b'True\r\n')
    false_outputs = (b'False\n', b'False\r\n')
    assert out in true_outputs + false_outputs
    return out in true_outputs


class TestImportError(unittest.TestCase):
    """Importing cupy in a fresh interpreter either succeeds silently or
    raises ``RuntimeError``."""

    def test_import_error(self):
        # The child prints the exception class name if ``import cupy``
        # raises; a clean import produces no output at all.
        returncode, stdoutdata, stderrdata = _run_script('''
try:
    import cupy
except Exception as e:
    print(type(e).__name__)
''')
        assert returncode == 0, 'stderr: {!r}'.format(stderrdata)
        # Accept Windows line endings as well, consistent with the
        # True/False output check in _test_cupy_available.
        assert stdoutdata in (b'', b'RuntimeError\n', b'RuntimeError\r\n')


# Name of the device-visibility environment variable used by the tests
# below: HIP_VISIBLE_DEVICES on HIP builds, CUDA_VISIBLE_DEVICES otherwise.
visible = ('HIP_VISIBLE_DEVICES' if cupy.cuda.runtime.is_hip
           else 'CUDA_VISIBLE_DEVICES')


class TestAvailable(unittest.TestCase):
    """CuPy reports itself as available when queried from a fresh process."""

    def test_available(self):
        assert _test_cupy_available(self)


class TestNotAvailable(unittest.TestCase):
    """CuPy reports itself unavailable when no device is made visible.

    Each test overwrites the device-visibility environment variable in
    ``os.environ``; the child interpreter started by ``_run_script`` inherits
    it. ``setUp``/``tearDown`` save and restore the original value.
    """

    def setUp(self):
        # None means the variable was not set before the test.
        self.old = os.environ.get(visible)

    def tearDown(self):
        if self.old is None:
            # Remove the variable again. Pass a default so that tearDown
            # never raises KeyError (masking the real test outcome) when
            # the test did not set this particular variable.
            os.environ.pop(visible, None)
        else:
            os.environ[visible] = self.old

    @unittest.skipIf(cupy.cuda.runtime.is_hip,
                     'HIP handles empty HIP_VISIBLE_DEVICES differently')
    def test_no_device_1(self):
        # A blank device list; only exercised on CUDA builds.
        os.environ['CUDA_VISIBLE_DEVICES'] = ' '
        available = _test_cupy_available(self)
        assert not available

    def test_no_device_2(self):
        # An invalid device index; CuPy is expected to find no device.
        os.environ[visible] = '-1'
        available = _test_cupy_available(self)
        assert not available


class TestMemoryPool(unittest.TestCase):
    """The default pool accessors return pools of the expected classes."""

    def test_get_default_memory_pool(self):
        assert isinstance(cupy.get_default_memory_pool(),
                          cupy.cuda.memory.MemoryPool)

    def test_get_default_pinned_memory_pool(self):
        pool = cupy.get_default_pinned_memory_pool()
        expected_cls = cupy.cuda.pinned_memory.PinnedMemoryPool
        assert isinstance(pool, expected_cls)


class TestShowConfig(unittest.TestCase):
    """``cupy.show_config`` writes the runtime info to stdout in one call."""

    def _assert_writes_runtime_info(self, full, **show_config_kwargs):
        # Intercept writes to the current sys.stdout, then check that exactly
        # one write carried the stringified runtime info.
        with mock.patch('sys.stdout.write') as fake_write:
            cupy.show_config(**show_config_kwargs)
        expected = str(cupyx.get_runtime_info(full=full))
        fake_write.assert_called_once_with(expected)

    def test_show_config(self):
        self._assert_writes_runtime_info(False)

    def test_show_config_with_handles(self):
        self._assert_writes_runtime_info(True, _full=True)


class TestAliases(unittest.TestCase):
    """NumPy and CuPy both expose these names as the very same objects."""

    def _assert_alias(self, alias, target):
        # Identity (not equality): the alias must be the same function object.
        for module in (numpy, cupy):
            assert getattr(module, alias) is getattr(module, target)

    def test_abs_is_absolute(self):
        self._assert_alias('abs', 'absolute')

    def test_conj_is_conjugate(self):
        self._assert_alias('conj', 'conjugate')

    def test_bitwise_not_is_invert(self):
        self._assert_alias('bitwise_not', 'invert')


def _resolve_error_class(module, name):
    """Return the class at dotted path ``name`` within ``module``.

    NumPy 2.0 stopped exposing several warning/exception classes (e.g.
    ``ComplexWarning``, ``VisibleDeprecationWarning``) in its main namespace;
    they live in the ``exceptions`` submodule instead. If the direct lookup
    fails, retry on ``module.exceptions`` when that submodule exists, so the
    test works with both old and new NumPy (and CuPy) layouts.

    Raises:
        AttributeError: if ``name`` is found in neither location.
    """
    get = operator.attrgetter(name)
    try:
        return get(module)
    except AttributeError:
        exceptions = getattr(module, 'exceptions', None)
        if exceptions is None:
            raise
        return get(exceptions)


@pytest.mark.parametrize('name', [
    'AxisError',
    'ComplexWarning',
    'ModuleDeprecationWarning',
    'RankWarning',
    'TooHardError',
    'VisibleDeprecationWarning',
    'linalg.LinAlgError',
])
def test_error_classes(name):
    """Each CuPy error/warning class subclasses its NumPy counterpart."""
    assert issubclass(_resolve_error_class(cupy, name),
                      _resolve_error_class(numpy, name))


# Entry point for running this file directly (``python test_init.py``).
# TODO: this was copied from chainer/testing/__init__.py and should
# eventually be replaced with a shared mechanism.
if __name__ == '__main__':
    # pytest is already imported at module level. Propagate pytest's exit
    # status so a failing run is reported as a failure to the caller.
    sys.exit(pytest.main([__file__, '-vvs', '-x', '--pdb']))