# test_sparse_emb.py
import multiprocessing as mp
import os
import unittest

import backend as F

import pytest
import torch as th

from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam


def initializer(emb):
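    """Deterministically fill ``emb`` with uniform values in [-1, 1]."""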
    th.manual_seed(0)
    emb.uniform_(-1.0, 1.0)
    return emb


def check_all_set_all_get_emb(device, init_emb):
    """Collectively set the embedding, then verify it reads back intact."""
    num_embs = init_emb.shape[0]
    emb_dim = init_emb.shape[1]
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test", device=device)

    dgl_emb.all_set_embedding(init_emb)

    out_emb = dgl_emb.all_get_embedding()
    assert F.allclose(init_emb, out_emb)


def check_all_set_all_get_optm_state(
    device, state_step, state_mem, state_power
):
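    """Collectively set sparse-Adam state, then verify it reads back."""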
    num_embs = state_mem.shape[0]
    emb_dim = state_mem.shape[1]
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test", device=device)
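    # Creating the optimizer attaches Adam state to the embedding.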
    optm = SparseAdam(params=[dgl_emb], lr=0.01)

    dgl_emb._all_set_optm_state((state_step, state_mem, state_power))

    out_step, out_mem, out_power = dgl_emb._all_get_optm_state()

    assert F.allclose(state_step, out_step)
    assert F.allclose(state_mem, out_mem)
    assert F.allclose(state_power, out_power)


def start_sparse_worker(rank, world_size, test, args):
    """Initialize a gloo process group and run ``test`` on this worker."""
    print("start sparse worker {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )
    backend = "gloo"
    device = F.ctx()
    if device.type == "cuda":
        device = th.device(rank)
        th.cuda.set_device(device)

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    test(device, *args)
    th.distributed.barrier()
    th.distributed.destroy_process_group()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set(num_workers):
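    """Spawn workers and round-trip a shared embedding through set/get."""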
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    worker_list = []

    # Reference embedding checked by every worker.
    init_emb = th.rand([1000, 8])

    # "spawn" starts fresh interpreter processes (required for CUDA workers).
    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        p = ctx.Process(
            target=start_sparse_worker,
            args=(i, num_workers, check_all_set_all_get_emb, (init_emb,)),
        )
        p.start()
        worker_list.append(p)

    for p in worker_list:
        p.join()
    for p in worker_list:
        assert p.exitcode == 0


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set_optm_state(num_workers):
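    """Spawn workers and round-trip sparse-Adam state through set/get."""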
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    worker_list = []

    num_embs, emb_dim = 1000, 8
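    # Random stand-in optimizer state: per-row step counts plus two
    # moment tensors.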
    state_step = th.randint(1000, (num_embs,))
    state_mem = th.rand((num_embs, emb_dim))
    state_power = th.rand((num_embs, emb_dim))

    ctx = mp.get_context("spawn")
    for i in range(num_workers):
        p = ctx.Process(
            target=start_sparse_worker,
            args=(
                i,
                num_workers,
                check_all_set_all_get_optm_state,
                (state_step, state_mem, state_power),
            ),
        )
        p.start()
        worker_list.append(p)

    for p in worker_list:
        p.join()
    for p in worker_list:
        assert p.exitcode == 0


if __name__ == "__main__":
    # test_multiprocess_sparse_emb_get_set(1)
    # test_multiprocess_sparse_emb_get_set(2)
    # test_multiprocess_sparse_emb_get_set(3)

    test_multiprocess_sparse_emb_get_set_optm_state(1)
    # test_multiprocess_sparse_emb_get_set_optm_state(2)
    # test_multiprocess_sparse_emb_get_set_optm_state(3)