test_comm.py 5.44 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import pathlib
import subprocess
import sys
import unittest
import os

import numpy
import pytest

from cupy.cuda import nccl
from cupy import testing

from cupyx.distributed import init_process_group
from cupyx.distributed._nccl_comm import _mpi_available


# Module-level alias of ``nccl.available``; the ``skipif`` marks on the test
# classes in this file read it to skip NCCL tests when it is false.
nccl_available = nccl.available


def _run_test(test_name, dtype=None):
    # subprocess is required not to interfere with cupy module imported in top
    # of this file
    runner_path = pathlib.Path(__file__).parent / 'comm_runner.py'
    args = [sys.executable, runner_path, test_name, 'store']
    if dtype is not None:
        args.append(numpy.dtype(dtype).char)
    proc = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    stdoutdata, stderrdata = proc.communicate()
    assert stderrdata.decode() == ''
    assert proc.returncode == 0


def _run_test_with_mpi(test_name, dtype=None):
    # subprocess is required not to interfere with cupy module imported in top
    # of this file
    runner_path = pathlib.Path(__file__).parent / 'comm_runner.py'
    args = ['mpiexec', '-n', '2', '--allow-run-as-root',
            sys.executable, runner_path, test_name, 'mpi']
    if dtype is not None:
        args.append(numpy.dtype(dtype).char)
    proc = subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=os.environ
    )
    stdoutdata, stderrdata = proc.communicate()
    assert stderrdata.decode() == ''
    assert proc.returncode == 0


@pytest.mark.skipif(not nccl_available, reason='nccl is not installed')
@testing.multi_gpu(2)
class TestNCCLBackend:
    """Run the communication scenarios of ``comm_runner.py`` one per test.

    Every test forwards a scenario name (and, except for ``barrier``, a
    dtype) to the ``_run_test`` hook of this class.  Subclasses override
    that hook to launch the same scenarios in a different way.  The whole
    class is skipped when ``nccl_available`` is false.

    NOTE(review): ``testing.multi_gpu(2)`` presumably limits the class to
    machines with at least two GPUs, and ``for_all_dtypes(no_bool=True)``
    presumably repeats a test for every dtype except bool -- confirm
    against ``cupy.testing``.
    """

    def _run_test(self, test, dtype):
        """Launch scenario ``test`` through the module-level ``_run_test``."""
        _run_test(test, dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_broadcast(self, dtype):
        self._run_test('broadcast', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_reduce(self, dtype):
        self._run_test('reduce', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_all_reduce(self, dtype):
        self._run_test('all_reduce', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_reduce_scatter(self, dtype):
        self._run_test('reduce_scatter', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_all_gather(self, dtype):
        self._run_test('all_gather', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_send_and_recv(self, dtype):
        self._run_test('send_and_recv', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_send_recv(self, dtype):
        self._run_test('send_recv', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_scatter(self, dtype):
        self._run_test('scatter', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_gather(self, dtype):
        self._run_test('gather', dtype)

    @testing.for_all_dtypes(no_bool=True)
    def test_all_to_all(self, dtype):
        self._run_test('all_to_all', dtype)

    def test_barrier(self):
        # The barrier scenario takes no dtype, so None is passed explicitly.
        self._run_test('barrier', None)


@pytest.mark.skipif(not _mpi_available, reason='mpi is not installed')
@testing.multi_gpu(2)
class TestNCCLBackendWithMPI(TestNCCLBackend):
    """Repeat every ``TestNCCLBackend`` test through ``_run_test_with_mpi``.

    Only the launch hook is overridden; all test methods are inherited.
    The class is skipped when ``_mpi_available`` is false.
    """

    def _run_test(self, test, dtype):
        """Launch scenario ``test`` with ``_run_test_with_mpi``."""
        _run_test_with_mpi(test, dtype)


@pytest.mark.skipif(not nccl_available, reason='nccl is not installed')
@testing.multi_gpu(2)
class TestNCCLBackendSparse:
    """Run the ``sparse_*`` scenarios of ``comm_runner.py`` one per test.

    Every test forwards a ``sparse_``-prefixed scenario name and a dtype to
    the ``_run_test`` hook of this class.  Subclasses override that hook to
    launch the same scenarios in a different way.  The whole class is
    skipped when ``nccl_available`` is false.

    NOTE(review): ``for_dtypes('fdFD')`` presumably repeats a test for the
    NumPy type codes ``f``, ``d``, ``F`` and ``D`` (single/double precision
    real and complex), and ``testing.multi_gpu(2)`` presumably requires at
    least two GPUs -- confirm against ``cupy.testing``.
    """

    def _run_test(self, test, dtype):
        """Launch scenario ``test`` through the module-level ``_run_test``."""
        _run_test(test, dtype)

    @testing.for_dtypes('fdFD')
    def test_send_and_recv(self, dtype):
        self._run_test('sparse_send_and_recv', dtype)

    @testing.for_dtypes('fdFD')
    def test_broadcast(self, dtype):
        self._run_test('sparse_broadcast', dtype)

    @testing.for_dtypes('fdFD')
    def test_reduce(self, dtype):
        self._run_test('sparse_reduce', dtype)

    @testing.for_dtypes('fdFD')
    def test_all_reduce(self, dtype):
        self._run_test('sparse_all_reduce', dtype)

    @testing.for_dtypes('fdFD')
    def test_scatter(self, dtype):
        self._run_test('sparse_scatter', dtype)

    @testing.for_dtypes('fdFD')
    def test_gather(self, dtype):
        self._run_test('sparse_gather', dtype)

    @testing.for_dtypes('fdFD')
    def test_all_gather(self, dtype):
        self._run_test('sparse_all_gather', dtype)

    @testing.for_dtypes('fdFD')
    def test_all_to_all(self, dtype):
        self._run_test('sparse_all_to_all', dtype)

    @testing.for_dtypes('fdFD')
    def test_reduce_scatter(self, dtype):
        self._run_test('sparse_reduce_scatter', dtype)

    @testing.for_dtypes('fdFD')
    def test_send_recv(self, dtype):
        self._run_test('sparse_send_recv', dtype)


@pytest.mark.skipif(not _mpi_available, reason='mpi is not installed')
@testing.multi_gpu(2)
class TestNCCLBackendSparseWithMPI(TestNCCLBackendSparse):
    """Repeat every ``TestNCCLBackendSparse`` test via ``_run_test_with_mpi``.

    Only the launch hook is overridden; all test methods are inherited.
    The class is skipped when ``_mpi_available`` is false.
    """

    def _run_test(self, test, dtype):
        """Launch scenario ``test`` with ``_run_test_with_mpi``."""
        _run_test_with_mpi(test, dtype)


@pytest.mark.skipif(not nccl_available, reason='nccl is not installed')
class TestInitDistributed(unittest.TestCase):
    """Start-up and argument-validation checks for ``init_process_group``.

    The class is skipped when ``nccl_available`` is false.  Only
    ``test_init`` is marked with ``testing.multi_gpu(2)``; the validation
    tests call ``init_process_group`` directly in this process.
    """

    @testing.multi_gpu(2)
    def test_init(self):
        # Runs the 'init' scenario of comm_runner.py in a subprocess.
        _run_test('init')

    def test_invalid_backend(self):
        # Passing backend='mpi' must be rejected with ValueError.
        bad_backend = 'mpi'
        with pytest.raises(ValueError):
            init_process_group(1, 0, backend=bad_backend)

    def test_invalid_n_devices(self):
        # Zero and negative values of the first argument must be rejected.
        for n_devices in (0, -1):
            with pytest.raises(ValueError):
                init_process_group(n_devices, 0)

    def test_invalid_rank(self):
        # With a first argument of 2, a second argument of -1 or 3 must be
        # rejected.
        for rank in (-1, 3):
            with pytest.raises(ValueError):
                init_process_group(2, rank)