"vscode:/vscode.git/clone" did not exist on "e13b8f5c3616bdc58fa847a848d63acdd416a692"
test_nccl.py 1.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from dgl.cuda import nccl
from dgl.partition import NDArrayPartition
import unittest
import backend as F


def gen_test_id():
    """Return a fixed, deterministic unique-id string for NCCL tests.

    The integer constant below is rendered as lowercase hexadecimal and
    zero-padded on the left to exactly 256 characters. The width presumably
    matches the hex encoding of a 128-byte ``ncclUniqueId``; confirm against
    the ``nccl.UniqueId`` string format.
    """
    seed = 78236728318467363
    return format(seed, '0256x')

@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_id():
    """Check that ``nccl.UniqueId`` round-trips through its string form.

    A freshly generated id must compare equal to an id rebuilt from its
    ``str()``. The deterministic id from ``gen_test_id()`` must differ from
    the generated one, and must itself survive the same string round-trip.
    """
    # Round-trip a freshly generated id through its text representation.
    generated = nccl.UniqueId()
    rebuilt = nccl.UniqueId(id_str=str(generated))
    assert generated == rebuilt

    # A fixed, known id should not collide with the randomly generated one...
    fixed = nccl.UniqueId(id_str=gen_test_id())
    assert fixed != generated

    # ...and should also reconstruct identically from its own string.
    fixed_rebuilt = nccl.UniqueId(id_str=str(fixed))
    assert fixed == fixed_rebuilt


@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_push_single():
    """Push sparse rows through a one-member communicator and get them back unchanged."""
    num_rows = 10000
    feat_dim = 100

    # Communicator(1, 0, id): presumably world size 1 with this process as rank 0.
    comm = nccl.Communicator(1, 0, nccl.UniqueId())

    index = F.randint([num_rows], F.int32, F.ctx(), 0, num_rows)
    value = F.uniform([num_rows, feat_dim], F.float32, F.ctx(), -1.0, 1.0)

    # A single-part 'remainder' partition, so the push should be an identity.
    part = NDArrayPartition(num_rows, 1, 'remainder')

    recv_index, recv_value = comm.sparse_all_to_all_push(index, value, part)
    assert F.array_equal(recv_index, index)
    assert F.array_equal(recv_value, value)

@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_pull_single():
    """Pull rows by index through a one-member communicator and compare with a local gather."""
    num_requests = 10000
    num_rows = 100000
    feat_dim = 100

    # Communicator(1, 0, id): presumably world size 1 with this process as rank 0.
    comm = nccl.Communicator(1, 0, nccl.UniqueId())

    req_index = F.randint([num_requests], F.int64, F.ctx(), 0, num_rows)
    value = F.uniform([num_rows, feat_dim], F.float32, F.ctx(), -1.0, 1.0)

    # A single-part 'remainder' partition over all rows of `value`.
    part = NDArrayPartition(num_rows, 1, 'remainder')

    pulled = comm.sparse_all_to_all_pull(req_index, value, part)
    # Reference result: gather the requested rows directly.
    expected = F.gather_row(value, req_index)
    assert F.array_equal(pulled, expected)


if __name__ == '__main__':
    # Run every test in declaration order when executed as a script.
    for _test in (test_nccl_id,
                  test_nccl_sparse_push_single,
                  test_nccl_sparse_pull_single):
        _test()