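"""Test the distributed sparse optimizers on DistEmbedding against torch.optim.SparseAdam."""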
import os

os.environ["OMP_NUM_THREADS"] = "1"
import multiprocessing as mp
import pickle
import random
import socket
import sys
import time
import unittest

import backend as F
import numpy as np
import torch as th
from scipy import sparse as spsp

import dgl
from dgl import function as fn
from dgl.distributed import (
    DistEmbedding,
    DistGraph,
    DistGraphServer,
    load_partition_book,
    partition_graph,
)
from dgl.distributed.optim import SparseAdagrad, SparseAdam


def create_random_graph(n):
    arr = (
        spsp.random(n, n, density=0.001, format="coo", random_state=100) != 0
    ).astype(np.int64)
    return dgl.from_scipy(arr)


def get_local_usable_addr():
    """Get local usable IP and port

    Returns
    -------
    str
        IP address and port separated by a space, e.g., '192.168.8.12 50051'
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # doesn't even have to be reachable
        sock.connect(("10.255.255.255", 1))
        ip_addr = sock.getsockname()[0]
    except OSError:
        ip_addr = "127.0.0.1"
    finally:
        sock.close()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    sock.listen(1)
    port = sock.getsockname()[1]
    sock.close()

    return ip_addr + " " + str(port)


def prepare_dist():
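    # Write a one-line ip_config file that points at this machine.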
    ip_addr = get_local_usable_addr()
    with open("optim_ip_config.txt", "w") as ip_config:
        ip_config.write("{}\n".format(ip_addr))


def run_server(graph_name, server_id, server_count, num_clients, shared_mem):
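    # Load the partitioned graph and serve it to the trainer clients.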
    g = DistGraphServer(
        server_id,
        "optim_ip_config.txt",
        num_clients,
        server_count,
        "/tmp/dist_graph/{}.json".format(graph_name),
        disable_shared_mem=not shared_mem,
    )
    print("start server", server_id)
    g.start()


def initializer(shape, dtype):
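    # Deterministic initialization: the torch reference embeddings below are
    # seeded the same way, so both start from identical weights.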
    arr = th.zeros(shape, dtype=dtype)
    th.manual_seed(0)
    th.nn.init.uniform_(arr, 0, 1.0)
    return arr


def run_client(graph_name, cli_id, part_id, server_count):
    device = F.ctx()
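    # Give the server processes a few seconds to start before connecting.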
    time.sleep(5)
    os.environ["DGL_NUM_SERVER"] = str(server_count)
    dgl.distributed.initialize("optim_ip_config.txt")
    gpb, graph_name, _, _ = load_partition_book(
        "/tmp/dist_graph/{}.json".format(graph_name), part_id, None
    )
    g = DistGraph(graph_name, gpb=gpb)
    policy = dgl.distributed.PartitionPolicy("node", g.get_partition_book())
    num_nodes = g.number_of_nodes()
    emb_dim = 4
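    # Two DistEmbedding tables are handed to one sparse optimizer below; the
    # "optim-zero" embedding never appears in the forward pass.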
    dgl_emb = DistEmbedding(
        num_nodes,
        emb_dim,
        name="optim",
        init_func=initializer,
        part_policy=policy,
    )
    dgl_emb_zero = DistEmbedding(
        num_nodes,
        emb_dim,
        name="optim-zero",
        init_func=initializer,
        part_policy=policy,
    )
    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
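    # Run the distributed optimizer as a single rank so its update can be
    # compared against the plain torch.optim.SparseAdam reference below.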
    dgl_adam._world_size = 1
    dgl_adam._rank = 0

    torch_emb = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
    torch_emb_zero = th.nn.Embedding(num_nodes, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
    torch_adam = th.optim.SparseAdam(
        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),
        lr=0.01,
    )

    labels = th.ones((4,)).long()
    idx = th.randint(0, num_nodes, size=(4,))
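    # One identical forward/backward/step through both optimizers.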
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    torch_adam.zero_grad()
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    torch_loss.backward()
    torch_adam.step()

    dgl_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_loss.backward()
    dgl_adam.step()

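    # After one identical update, the distributed embedding should match the
    # torch reference.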
    assert F.allclose(
        dgl_emb.weight[0 : num_nodes // 2], torch_emb.weight[0 : num_nodes // 2]
    )


def check_sparse_adam(num_trainer=1, shared_mem=True):
    prepare_dist()
    g = create_random_graph(2000)
    num_servers = num_trainer
    num_clients = num_trainer
    num_parts = 1

    graph_name = "dist_graph_test"
    partition_graph(g, graph_name, num_parts, "/tmp/dist_graph")

    # let's just test on one partition for now.
    # We cannot run multiple servers and clients on the same machine.
    serv_ps = []
    ctx = mp.get_context("spawn")
    for serv_id in range(num_servers):
        p = ctx.Process(
            target=run_server,
            args=(graph_name, serv_id, num_servers, num_clients, shared_mem),
        )
        serv_ps.append(p)
        p.start()

    cli_ps = []
    for cli_id in range(num_clients):
        print("start client", cli_id)
        p = ctx.Process(
            target=run_client, args=(graph_name, cli_id, 0, num_servers)
        )
        p.start()
        cli_ps.append(p)

    for p in cli_ps:
        p.join()

    for p in serv_ps:
        p.join()


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
def test_sparse_opt():
    os.environ["DGL_DIST_MODE"] = "distributed"
    check_sparse_adam(1, True)
    check_sparse_adam(1, False)


if __name__ == "__main__":
    os.makedirs("/tmp/dist_graph", exist_ok=True)
    test_sparse_opt()