".ci/check-python-dists.sh" did not exist on "cc733f8595267f313886f92ed5d1285010ba8f3f"
test_nccl.py 3.1 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import pickle
import unittest

import cupy
from cupy import cuda
from cupy.cuda import nccl
from cupy import testing


# Feature probes evaluated once at import time; the skip decorators below
# consult them.  When NCCL is missing the version is reported as -1 so that
# every ``nccl_version >= N`` requirement evaluates to False.
nccl_available = nccl.available
nccl_version = nccl.get_version() if nccl_available else -1


@unittest.skipUnless(nccl_available, 'nccl is not installed')
class TestNCCL(unittest.TestCase):
    """Single-process tests for the ``cupy.cuda.nccl`` communicator bindings.

    Communicators are released through ``addCleanup`` rather than an explicit
    ``destroy()`` at the end of each test, so a failing assertion does not
    leak NCCL resources into the tests that run afterwards.
    """

    def _new_single_rank_comm(self):
        """Create a 1-rank communicator (rank 0) and schedule its destroy."""
        uid = nccl.get_unique_id()
        comm = nccl.NcclCommunicator(1, uid, 0)
        self.addCleanup(comm.destroy)
        return comm

    def _init_all(self, devices):
        """Call ``NcclCommunicator.initAll`` and schedule every destroy.

        ``devices`` is passed through unchanged (the tests below use both an
        int count and a list of device ids).
        """
        comms = nccl.NcclCommunicator.initAll(devices)
        for comm in comms:
            self.addCleanup(comm.destroy)
        return comms

    def test_single_proc_ring(self):
        comm = self._new_single_rank_comm()
        assert 0 == comm.rank_id()

    @unittest.skipUnless(nccl_version >= 2400, 'Using old NCCL')
    def test_abort(self):
        uid = nccl.get_unique_id()
        comm = nccl.NcclCommunicator(1, uid, 0)
        # abort() tears the communicator down itself (ncclCommAbort), so it
        # is deliberately not registered for a destroy() afterwards.
        comm.abort()

    @unittest.skipUnless(nccl_version >= 2400, 'Using old NCCL')
    def test_check_async_error(self):
        comm = self._new_single_rank_comm()
        comm.check_async_error()

    def test_init_all(self):
        comms = self._init_all(1)
        for rank, comm in enumerate(comms):
            assert rank == comm.rank_id()

    def test_single_proc_single_dev(self):
        comms = self._init_all(1)
        results = []
        nccl.groupStart()
        for comm in comms:
            dev_id = comm.device_id()
            cuda.Device(dev_id).use()
            # The dtype must match the NCCL_INT64 passed below: the default
            # integer dtype of cupy.arange is platform dependent (e.g. 32-bit
            # on Windows), and a mismatch makes NCCL access memory past the
            # end of the buffers.
            sendbuf = cupy.arange(10, dtype=cupy.int64)
            recvbuf = cupy.zeros_like(sendbuf)
            comm.allReduce(sendbuf.data.ptr, recvbuf.data.ptr, 10,
                           nccl.NCCL_INT64, nccl.NCCL_SUM,
                           cuda.Stream.null.ptr)
            results.append((dev_id, sendbuf, recvbuf))
        nccl.groupEnd()

        # With a single rank, a SUM all-reduce must reproduce the input.
        for dev_id, sendbuf, recvbuf in results:
            with cuda.Device(dev_id):
                assert cupy.array_equal(sendbuf, recvbuf)

    def test_comm_size(self):
        comm = self._new_single_rank_comm()
        assert 1 == comm.size()

    @testing.multi_gpu(2)
    @unittest.skipUnless(nccl_version >= 2700, 'Using old NCCL')
    def test_send_recv(self):
        devs = [0, 1]
        comms = self._init_all(devs)
        recvbuf = None
        nccl.groupStart()
        for comm in comms:
            dev_id = comm.device_id()
            rank = comm.rank_id()
            # initAll(devs) is expected to assign rank i to devs[i].
            assert rank == dev_id

            if rank == 0:
                with cuda.Device(dev_id):
                    sendbuf = cupy.arange(10, dtype=cupy.int64)
                    comm.send(sendbuf.data.ptr, 10, nccl.NCCL_INT64,
                              1, cuda.Stream.null.ptr)
            elif rank == 1:
                with cuda.Device(dev_id):
                    recvbuf = cupy.zeros(10, dtype=cupy.int64)
                    comm.recv(recvbuf.data.ptr, 10, nccl.NCCL_INT64,
                              0, cuda.Stream.null.ptr)
        nccl.groupEnd()

        # check result on the receiving device
        assert recvbuf is not None
        with cuda.Device(1):
            expected = cupy.arange(10, dtype=cupy.int64)
            assert (recvbuf == expected).all()


@unittest.skipUnless(nccl_available, 'nccl is not installed')
class TestExceptionPicklable(unittest.TestCase):
    """``nccl.NcclError`` must survive a pickle round trip unchanged."""

    def test(self):
        before = nccl.NcclError(1)
        after = pickle.loads(pickle.dumps(before))
        # Compare the constructor arguments first, then the rendered message.
        for project in (lambda err: err.args, str):
            assert project(before) == project(after)