# test_sparse_emb.py
import multiprocessing as mp

import os
import unittest

import backend as F

import pytest
import torch as th

from dgl.nn import NodeEmbedding


def initializer(emb):
    """Fill ``emb`` in-place with values drawn from U(-1, 1) and return it.

    The RNG is re-seeded to 0 on every call, so every worker process
    produces the identical initial embedding tensor.
    """
    th.manual_seed(0)
    return emb.uniform_(-1.0, 1.0)

def check_all_set_all_get_func(device, init_emb):
    """Round-trip check: write ``init_emb`` into a NodeEmbedding, read it back.

    Must run inside an initialized torch.distributed process group, since
    ``all_set_embedding`` / ``all_get_embedding`` are collective calls.
    """
    num_embs, emb_dim = init_emb.shape
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test", device=device)

    dgl_emb.all_set_embedding(init_emb)
    fetched = dgl_emb.all_get_embedding()
    assert F.allclose(init_emb, fetched)

def start_sparse_worker(rank, world_size, test, args):
    """Join a torch.distributed process group and run ``test`` on this worker.

    Parameters
    ----------
    rank : int
        Rank of this worker within the group.
    world_size : int
        Total number of workers in the group.
    test : callable
        Invoked as ``test(device, *args)`` once the group is up.
    args : tuple
        Extra positional arguments forwarded to ``test``.
    """
    print("start sparse worker {}".format(rank))
    init_url = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )

    device = F.ctx()
    if device.type == "cuda":
        # One GPU per worker: bind this process to the GPU matching its rank.
        device = th.device(rank)
        th.cuda.set_device(device)

    th.distributed.init_process_group(
        backend="gloo",
        init_method=init_url,
        world_size=world_size,
        rank=rank,
    )

    test(device, *args)
    # Keep all workers alive until everyone has finished the test body.
    th.distributed.barrier()

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set(num_workers):
    """Spawn ``num_workers`` processes and verify NodeEmbedding set/get round-trips."""
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    init_emb = th.rand([1000, 8])

    # "spawn" rather than "fork": required for CUDA, safe everywhere else.
    spawn_ctx = mp.get_context("spawn")
    workers = [
        spawn_ctx.Process(
            target=start_sparse_worker,
            args=(rank, num_workers, check_all_set_all_get_func, (init_emb,)),
        )
        for rank in range(num_workers)
    ]

    for proc in workers:
        proc.start()
    for proc in workers:
        proc.join()
    for proc in workers:
        # A non-zero exit code means the worker's assertions failed.
        assert proc.exitcode == 0

if __name__ == "__main__":
    # BUGFIX: the original called the nonexistent ``test_sparse_emb_get_set``;
    # the function defined above is ``test_multiprocess_sparse_emb_get_set``.
    # Exercise it with 1, 2 and 3 workers, mirroring the pytest parametrization.
    test_multiprocess_sparse_emb_get_set(1)
    test_multiprocess_sparse_emb_get_set(2)
    test_multiprocess_sparse_emb_get_set(3)